[hlsl] Add ValueToLet transform.
This CL adds the ValueToLet transform into the HLSL IR backend.
Bug: 42251045
Change-Id: I63ca41e1a01c6e133fe54c6788190df9a896f6b7
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/195802
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/tint/lang/hlsl/writer/binary_test.cc b/src/tint/lang/hlsl/writer/binary_test.cc
index fefce70..6795ce1 100644
--- a/src/tint/lang/hlsl/writer/binary_test.cc
+++ b/src/tint/lang/hlsl/writer/binary_test.cc
@@ -122,7 +122,8 @@
ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
EXPECT_EQ(output_.hlsl, R"(
uint tint_mod_u32(uint lhs, uint rhs) {
- return (lhs - ((lhs / (((rhs == 0u)) ? (1u) : (rhs))) * (((rhs == 0u)) ? (1u) : (rhs))));
+ uint v = (((rhs == 0u)) ? (1u) : (rhs));
+ return (lhs - ((lhs / v) * v));
}
[numthreads(1, 1, 1)]
diff --git a/src/tint/lang/hlsl/writer/bitcast_test.cc b/src/tint/lang/hlsl/writer/bitcast_test.cc
index 81ad93d..d873ada 100644
--- a/src/tint/lang/hlsl/writer/bitcast_test.cc
+++ b/src/tint/lang/hlsl/writer/bitcast_test.cc
@@ -158,9 +158,10 @@
void foo() {
vector<float16_t, 2> a = vector<float16_t, 2>(float16_t(1.0h), float16_t(2.0h));
- int b = tint_bitcast_from_f16(a);
- float c = tint_bitcast_from_f16_1(a);
- uint d = tint_bitcast_from_f16_2(a);
+ vector<float16_t, 2> v = a;
+ int b = tint_bitcast_from_f16(v);
+ float c = tint_bitcast_from_f16_1(v);
+ uint d = tint_bitcast_from_f16_2(v);
}
)");
@@ -245,9 +246,10 @@
void foo() {
vector<float16_t, 4> a = vector<float16_t, 4>(float16_t(1.0h), float16_t(2.0h), float16_t(3.0h), float16_t(4.0h));
- int2 b = tint_bitcast_from_f16(a);
- float2 c = tint_bitcast_from_f16_1(a);
- uint2 d = tint_bitcast_from_f16_2(a);
+ vector<float16_t, 4> v = a;
+ int2 b = tint_bitcast_from_f16(v);
+ float2 c = tint_bitcast_from_f16_1(v);
+ uint2 d = tint_bitcast_from_f16_2(v);
}
)");
diff --git a/src/tint/lang/hlsl/writer/constant_test.cc b/src/tint/lang/hlsl/writer/constant_test.cc
index aff0f2a..29cfeee 100644
--- a/src/tint/lang/hlsl/writer/constant_test.cc
+++ b/src/tint/lang/hlsl/writer/constant_test.cc
@@ -878,7 +878,7 @@
};
auto* strct = ty.Struct(b.ir.symbols.New("S"), std::move(members));
- b.Append(b.ir.root_block, [&] { b.Var<private_>("p", b.Construct(strct, 3_i)); });
+ b.Append(b.ir.root_block, [&] { b.Var<private_>("p", b.Composite(strct, 3_i)); });
ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
EXPECT_EQ(output_.hlsl, R"(struct S {
diff --git a/src/tint/lang/hlsl/writer/function_test.cc b/src/tint/lang/hlsl/writer/function_test.cc
index ca481a5..0832503 100644
--- a/src/tint/lang/hlsl/writer/function_test.cc
+++ b/src/tint/lang/hlsl/writer/function_test.cc
@@ -192,31 +192,32 @@
auto* coord = b.FunctionParam("coord", ty.vec4<f32>());
coord->SetBuiltin(core::BuiltinValue::kPosition);
- auto* func = b.Function("frag_main", ty.f32());
+ auto* func = b.Function("frag_main", ty.f32(), core::ir::Function::PipelineStage::kFragment);
func->SetReturnBuiltin(core::BuiltinValue::kFragDepth);
func->SetParams({coord});
- b.Append(func->Block(), [&] { //
- auto* a = b.Access(ty.ptr(function, ty.f32()), coord, 0_u);
- b.Return(func, b.Load(a));
+ b.Append(func->Block(), [&] {
+ auto* a = b.Swizzle(ty.f32(), coord, {0});
+ b.Return(func, a);
});
ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
- EXPECT_EQ(output_.hlsl, R"(struct tint_symbol_1 {
+ EXPECT_EQ(output_.hlsl, R"(struct frag_main_outputs {
+ float tint_symbol : SV_Depth;
+};
+
+struct frag_main_inputs {
float4 coord : SV_Position;
};
-struct tint_symbol_2 {
- float value : SV_Depth;
-};
+
float frag_main_inner(float4 coord) {
return coord.x;
}
-tint_symbol_2 frag_main(tint_symbol_1 tint_symbol) {
- float inner_result = frag_main_inner(float4(tint_symbol.coord.xyz, (1.0f / tint_symbol.coord.w)));
- tint_symbol_2 wrapper_result = (tint_symbol_2)0;
- wrapper_result.value = inner_result;
- return wrapper_result;
+frag_main_outputs frag_main(frag_main_inputs inputs) {
+ float inner_result = frag_main_inner(float4(inputs.coord.xyz, (1.0f / inputs.coord.w)));
+ frag_main_outputs v = {inner_result};
+ return v;
}
)");
@@ -266,13 +267,13 @@
b.Function("frag_main", ty.void_(), core::ir::Function::PipelineStage::kFragment);
frag_func->SetParams({frag_param});
b.Append(frag_func->Block(), [&] {
- auto* r = b.Access(ty.ptr(function, ty.f32()), frag_param, 1_u);
- auto* g = b.Access(ty.ptr(function, ty.f32()), frag_param, 2_u);
- auto* p = b.Access(ty.ptr(function, ty.vec4<f32>()), frag_param, 0_u);
+ auto* r = b.Access(ty.f32(), frag_param, 1_u);
+ auto* g = b.Access(ty.f32(), frag_param, 2_u);
+ auto* p = b.Access(ty.vec4<f32>(), frag_param, 0_u);
- b.Let("r", b.Load(r));
- b.Let("g", b.Load(g));
- b.Let("p", b.Load(p));
+ b.Let("r", r);
+ b.Let("g", g);
+ b.Let("p", p);
b.Return(frag_func);
});
@@ -282,42 +283,40 @@
float col1;
float col2;
};
-struct tint_symbol {
- float col1 : TEXCOORD1;
- float col2 : TEXCOORD2;
- float4 pos : SV_Position;
+
+struct vert_main_outputs {
+ float Interface_col1 : TEXCOORD1;
+ float Interface_col2 : TEXCOORD2;
+ float4 Interface_pos : SV_Position;
};
+struct frag_main_inputs {
+ float Interface_col1 : TEXCOORD1;
+ float Interface_col2 : TEXCOORD2;
+ float4 Interface_pos : SV_Position;
+};
+
+
Interface vert_main_inner() {
Interface tint_symbol_3 = {(0.0f).xxxx, 0.5f, 0.25f};
return tint_symbol_3;
}
-tint_symbol vert_main() {
- Interface inner_result = vert_main_inner();
- tint_symbol wrapper_result = (tint_symbol)0;
- wrapper_result.pos = inner_result.pos;
- wrapper_result.col1 = inner_result.col1;
- wrapper_result.col2 = inner_result.col2;
- return wrapper_result;
-}
-
-struct tint_symbol_2 {
- float col1 : TEXCOORD1;
- float col2 : TEXCOORD2;
- float4 pos : SV_Position;
-};
-
void frag_main_inner(Interface inputs) {
float r = inputs.col1;
float g = inputs.col2;
float4 p = inputs.pos;
}
-void frag_main(tint_symbol_2 tint_symbol_1) {
- Interface tint_symbol_4 = {float4(tint_symbol_1.pos.xyz, (1.0f / tint_symbol_1.pos.w)), tint_symbol_1.col1, tint_symbol_1.col2};
- frag_main_inner(tint_symbol_4);
- return;
+vert_main_outputs vert_main() {
+ Interface v = vert_main_inner();
+ vert_main_outputs wrapper_result = {v.col1, v.col2, v.pos};
+ return wrapper_result;
+}
+
+void frag_main(frag_main_inputs inputs) {
+ Interface v = {float4(inputs.pos.xyz, (1.0f / inputs.pos.w)), inputs.col1, inputs.col2};
+ frag_main_inner(v);
}
)");
@@ -372,39 +371,37 @@
float4 pos;
};
-VertexOutput foo(float x) {
- VertexOutput tint_symbol_2 = {float4(x, x, x, 1.0f)};
- return tint_symbol_2;
-}
-
-struct tint_symbol {
- float4 pos : SV_Position;
+struct vert1_main1_outputs {
+ float4 VertexOutput_pos : SV_Position;
};
-VertexOutput vert_main1_inner() {
+struct vert2_main1_outputs {
+ float4 VertexOutput_pos : SV_Position;
+};
+
+
+VertexOutput foo(float x) {
+ return {float4(x, x, x, 1.0f)};
+}
+
+VertexOutput vert1_main1_inner() {
return foo(0.5f);
}
-tint_symbol vert_main1() {
- VertexOutput inner_result = vert_main1_inner();
- tint_symbol wrapper_result = (tint_symbol)0;
- wrapper_result.pos = inner_result.pos;
- return wrapper_result;
-}
-
-struct tint_symbol_1 {
- float4 pos : SV_Position;
-};
-
-VertexOutput vert_main2_inner() {
+VertexOutput vert2_main1_inner() {
return foo(0.25f);
}
-tint_symbol_1 vert_main2() {
+vert1_main1_outputs vert1_main1() {
+ VertexOutput inner_result = vert_main1_inner();
+ vert1_main1_outputs v = {inner_result.pos};
+ return v;
+}
+
+vert2_main1_outputs vert2_main1() {
VertexOutput inner_result_1 = vert_main2_inner();
- tint_symbol_1 wrapper_result_1 = (tint_symbol_1)0;
- wrapper_result_1.pos = inner_result_1.pos;
- return wrapper_result_1;
+ vert2_main1_outputs v = {inner_result_1.pos};
+ return v;
}
)");
@@ -423,8 +420,12 @@
// var v = sub_func(1f);
// }
- Vector members{ty.Get<core::type::StructMember>(b.ir.symbols.New("coord"), ty.vec4<f32>(), 0u,
- 0u, 16u, 16u,
+ Vector inner_members{ty.Get<core::type::StructMember>(
+ b.ir.symbols.New("coord"), ty.f32(), 0u, 0u, 4u, 4u, core::type::StructMemberAttributes{})};
+ auto* inner_strct = ty.Struct(b.ir.symbols.New("Inner"), std::move(inner_members));
+
+ Vector members{ty.Get<core::type::StructMember>(b.ir.symbols.New("coord"), inner_strct, 0u, 0u,
+ 16u, 16u,
core::type::StructMemberAttributes{})};
auto* strct = ty.Struct(b.ir.symbols.New("Uniforms"), std::move(members));
@@ -449,7 +450,8 @@
});
ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
- EXPECT_EQ(output_.hlsl, R"(cbuffer cbuffer_uniforms : register(b0, space1) {
+ EXPECT_EQ(output_.hlsl, R"(
+cbuffer cbuffer_ubo : register(b0, space1) {
uint4 ubo[1];
};
@@ -459,7 +461,6 @@
void frag_main() {
float v = sub_func(1.0f);
- return;
}
)");
@@ -476,8 +477,12 @@
// var v = ubo.coord.x;
// }
- Vector members{ty.Get<core::type::StructMember>(b.ir.symbols.New("coord"), ty.vec4<f32>(), 0u,
- 0u, 16u, 16u,
+ Vector inner_members{ty.Get<core::type::StructMember>(
+ b.ir.symbols.New("coord"), ty.f32(), 0u, 0u, 4u, 4u, core::type::StructMemberAttributes{})};
+ auto* inner_strct = ty.Struct(b.ir.symbols.New("Inner"), std::move(inner_members));
+
+ Vector members{ty.Get<core::type::StructMember>(b.ir.symbols.New("coord"), inner_strct, 0u, 0u,
+ 16u, 16u,
core::type::StructMemberAttributes{})};
auto* strct = ty.Struct(b.ir.symbols.New("Uniforms"), std::move(members));
@@ -493,10 +498,10 @@
});
ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
- EXPECT_EQ(output_.hlsl, R"(cbuffer cbuffer_uniforms : register(b0, space1) {
+ EXPECT_EQ(output_.hlsl, R"(
+cbuffer cbuffer_ubo : register(b0, space1) {
uint4 ubo[1];
};
-
void frag_main() {
float v = asfloat(ubo[0].x);
return;
@@ -528,7 +533,7 @@
auto* func = b.Function("frag_main", ty.void_(), core::ir::Function::PipelineStage::kFragment);
b.Append(func->Block(), [&] { //
- auto* a = b.Access(ty.ptr(storage, ty.f32()), coord, 0_u);
+ auto* a = b.Access(ty.ptr(storage, ty.i32()), coord, 0_u);
b.Var("v", b.Load(a));
b.Return(func);
@@ -536,10 +541,10 @@
ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
EXPECT_EQ(output_.hlsl,
- R"(RWByteAddressBuffer coord : register(u0, space1);
-
+ R"(
+RWByteAddressBuffer coord : register(u0, space1);
void frag_main() {
- float v = asfloat(coord.Load(4u));
+ int v = asint(coord.Load(4u));
return;
}
@@ -569,7 +574,7 @@
auto* func = b.Function("frag_main", ty.void_(), core::ir::Function::PipelineStage::kFragment);
b.Append(func->Block(), [&] { //
- auto* a = b.Access(ty.ptr<storage, f32>(), coord, 0_u);
+ auto* a = b.Access(ty.ptr<storage, i32, core::Access::kRead>(), coord, 0_u);
b.Var("v", b.Load(a));
b.Return(func);
});
@@ -579,7 +584,7 @@
R"(ByteAddressBuffer coord : register(t0, space1);
void frag_main() {
- float v = asfloat(coord.Load(4u));
+ int v = asint(coord.Load(4u));
return;
}
@@ -603,13 +608,14 @@
core::type::StructMemberAttributes{})};
auto* strct = ty.Struct(b.ir.symbols.New("Data"), std::move(members));
- auto* coord = b.Var("coord", storage, strct, core::Access::kWrite);
+ auto* coord = b.Var("coord", storage, strct, core::Access::kReadWrite);
coord->SetBindingPoint(1, 0);
b.ir.root_block->Append(coord);
auto* func = b.Function("frag_main", ty.void_(), core::ir::Function::PipelineStage::kFragment);
b.Append(func->Block(), [&] { //
b.Store(b.Access(ty.ptr(storage, ty.f32()), coord, 1_u), 2_f);
+ b.Return(func);
});
ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
@@ -648,12 +654,13 @@
auto* func = b.Function("frag_main", ty.void_(), core::ir::Function::PipelineStage::kFragment);
b.Append(func->Block(), [&] { //
b.Store(b.Access(ty.ptr(storage, ty.f32()), coord, 1_u), 2_f);
+ b.Return(func);
});
ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
EXPECT_EQ(output_.hlsl,
- R"(RWByteAddressBuffer coord : register(u0, space1);
-
+ R"(
+RWByteAddressBuffer coord : register(u0, space1);
void frag_main() {
coord.Store(4u, asuint(2.0f));
return;
@@ -685,7 +692,7 @@
auto* sub_func = b.Function("sub_func", ty.f32());
b.Append(sub_func->Block(), [&] {
- auto* a = b.Access(ty.ptr<storage, f32>(), coord, 0_u);
+ auto* a = b.Access(ty.ptr<uniform, f32, core::Access::kRead>(), coord, 0_u);
b.Return(sub_func, b.Load(a));
});
@@ -722,7 +729,7 @@
// return coord.x;
// }
// @fragment fn frag_main() {
- // var v = sub_func(1f);
+ // var v = sub_func();
// }
Vector members{ty.Get<core::type::StructMember>(b.ir.symbols.New("x"), ty.f32(), 0u, 0u, 4u, 4u,
@@ -741,21 +748,20 @@
auto* func = b.Function("frag_main", ty.void_(), core::ir::Function::PipelineStage::kFragment);
b.Append(func->Block(), [&] { //
- b.Var("v", b.Call(sub_func, 1_f));
+ b.Var("v", b.Call(sub_func));
b.Return(func);
});
ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
EXPECT_EQ(output_.hlsl,
- R"(RWByteAddressBuffer coord : register(u0, space1);
-
-float sub_func(float param) {
+ R"(
+RWByteAddressBuffer coord : register(u0, space1);
+float sub_func() {
return asfloat(coord.Load(0u));
}
void frag_main() {
- float v = sub_func(1.0f);
- return;
+ float v = sub_func();
}
)");
@@ -793,7 +799,7 @@
)");
}
-TEST_F(HlslWriterTest, DISABLED_FunctionWithArrayParams) {
+TEST_F(HlslWriterTest, FunctionWithArrayParams) {
// fn my_func(a: array<f32, 5>) {}
auto* func = b.Function("my_func", ty.void_());
@@ -802,8 +808,8 @@
func->Block()->Append(b.Return(func));
ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
- EXPECT_EQ(output_.hlsl, R"(void my_func(float a[5]) {
- return;
+ EXPECT_EQ(output_.hlsl, R"(
+void my_func(float a[5]) {
}
[numthreads(1, 1, 1)]
@@ -962,7 +968,6 @@
ASSERT_TRUE(Generate()) << err_ << output_.hlsl;
EXPECT_EQ(output_.hlsl, R"(
RWByteAddressBuffer data : register(u0);
-
[numthreads(1, 1, 1)]
void a() {
float v = asfloat(data.Load(0u));
diff --git a/src/tint/lang/hlsl/writer/raise/raise.cc b/src/tint/lang/hlsl/writer/raise/raise.cc
index c7055e0..98a7027 100644
--- a/src/tint/lang/hlsl/writer/raise/raise.cc
+++ b/src/tint/lang/hlsl/writer/raise/raise.cc
@@ -30,6 +30,7 @@
#include "src/tint/lang/core/ir/transform/add_empty_entry_point.h"
#include "src/tint/lang/core/ir/transform/binary_polyfill.h"
#include "src/tint/lang/core/ir/transform/remove_terminator_args.h"
+#include "src/tint/lang/core/ir/transform/value_to_let.h"
#include "src/tint/lang/hlsl/writer/common/options.h"
#include "src/tint/lang/hlsl/writer/raise/builtin_polyfill.h"
#include "src/tint/lang/hlsl/writer/raise/fxc_polyfill.h"
@@ -66,6 +67,7 @@
// These transforms need to be run last as various transforms introduce terminator arguments,
// naming conflicts, and expressions that need to be explicitly not inlined.
RUN_TRANSFORM(core::ir::transform::RemoveTerminatorArgs);
+ RUN_TRANSFORM(core::ir::transform::ValueToLet);
return Success;
}