blob: ecd691e77a071c3e061dd24f3e0f69b302d32020 [file] [log] [blame]
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/ast/transform/truncate_interstage_variables.h"
#include "src/tint/ast/transform/canonicalize_entry_point_io.h"
#include "gmock/gmock.h"
#include "src/tint/ast/transform/test_helper.h"
namespace tint::ast::transform {
namespace {
using ::testing::ContainerEq;
using TruncateInterstageVariablesTest = TransformTest;
// ShouldRun() for a vertex entry point whose output struct carries @location(0) and @location(2).
// The transform is expected to be skipped only when every written location is marked in the Config.
TEST_F(TruncateInterstageVariablesTest, ShouldRunVertex) {
    auto* shader = R"(
struct ShaderIO {
@builtin(position) pos: vec4<f32>,
@location(0) f_0: f32,
@location(2) f_2: f32,
}
@vertex
fn f() -> ShaderIO {
var io: ShaderIO;
io.f_0 = 1.0;
io.f_2 = io.f_2 + 3.0;
return io;
}
)";

    // No Config supplied at all: running the transform reports an error.
    {
        auto* expected =
            "error: missing transform data for tint::ast::transform::TruncateInterstageVariables";
        auto output = Run<TruncateInterstageVariables>(shader);
        EXPECT_EQ(expected, str(output));
    }

    // Default Config (no interstage location marked): every location output gets truncated, so
    // the transform must run.
    {
        TruncateInterstageVariables::Config config;
        Transform::DataMap inputs;
        inputs.Add<TruncateInterstageVariables::Config>(config);
        EXPECT_TRUE(ShouldRun<TruncateInterstageVariables>(shader, inputs));
    }

    // Both written locations (0 and 2) are marked: nothing to truncate, so it must not run.
    {
        TruncateInterstageVariables::Config config;
        config.interstage_locations[0] = true;
        config.interstage_locations[2] = true;
        Transform::DataMap inputs;
        inputs.Add<TruncateInterstageVariables::Config>(config);
        EXPECT_FALSE(ShouldRun<TruncateInterstageVariables>(shader, inputs));
    }

    // Only location 2 is marked: location 0 still has to be truncated, so it must run.
    {
        TruncateInterstageVariables::Config config;
        config.interstage_locations[2] = true;
        Transform::DataMap inputs;
        inputs.Add<TruncateInterstageVariables::Config>(config);
        EXPECT_TRUE(ShouldRun<TruncateInterstageVariables>(shader, inputs));
    }
}
// ShouldRun() must return false for a module whose only entry point is a fragment shader, even
// though that shader reads @location(0), which is not marked in the Config.
TEST_F(TruncateInterstageVariablesTest, ShouldRunFragment) {
    TruncateInterstageVariables::Config config;
    config.interstage_locations[2] = true;

    Transform::DataMap inputs;
    inputs.Add<TruncateInterstageVariables::Config>(config);

    auto* shader = R"(
struct ShaderIO {
@location(0) f_0: f32,
@location(2) f_2: f32,
}
@fragment
fn f(io: ShaderIO) -> @location(1) vec4<f32> {
return vec4<f32>(io.f_0, io.f_2, 0.0, 1.0);
}
)";

    EXPECT_FALSE(ShouldRun<TruncateInterstageVariables>(shader, inputs));
}
// Test that this transform is meant to run after CanonicalizeEntryPointIO, at which point the shader
// IO is already grouped into a struct.
TEST_F(TruncateInterstageVariablesTest, ShouldRunAfterCanonicalizeEntryPointIO) {
    auto* src = R"(
@vertex
fn f() -> @builtin(position) vec4<f32> {
return vec4<f32>(1.0, 1.0, 1.0, 1.0);
}
)";
    TruncateInterstageVariables::Config cfg;
    cfg.interstage_locations[0] = true;
    Transform::DataMap data;
    data.Add<TruncateInterstageVariables::Config>(cfg);
    data.Add<CanonicalizeEntryPointIO::Config>(CanonicalizeEntryPointIO::ShaderStyle::kHlsl);
    // Canonicalize first, then ask whether truncation is needed on the canonicalized module.
    auto got = Run<CanonicalizeEntryPointIO>(src, data);
    // The entry point returns only the @builtin(position) value and has no @location outputs.
    // Even once CanonicalizeEntryPointIO has grouped its output into a struct there is nothing
    // to truncate, so the transform is not expected to run.
    EXPECT_FALSE(ShouldRun<TruncateInterstageVariables>(str(got), data));
}
// Only @location(1) is marked as consumed, so @location(3) (the last member) is dropped. The entry
// point now returns a generated struct holding just the position builtin and f_1. That struct is
// built by a generated truncate_shader_output() helper, and ShaderIO loses its IO attributes.
TEST_F(TruncateInterstageVariablesTest, BasicVertexTrimLocationInMid) {
    // The consuming fragment stage reads only @location(1).
    TruncateInterstageVariables::Config config;
    config.interstage_locations[1] = true;

    Transform::DataMap inputs;
    inputs.Add<TruncateInterstageVariables::Config>(config);

    auto* shader = R"(
struct ShaderIO {
@builtin(position) pos: vec4<f32>,
@location(1) f_1: f32,
@location(3) f_3: f32,
}
@vertex
fn f() -> ShaderIO {
var io: ShaderIO;
io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
io.f_1 = 1.0;
io.f_3 = io.f_1 + 3.0;
return io;
}
)";
    auto* expected = R"(
struct tint_symbol {
@builtin(position)
pos : vec4<f32>,
@location(1)
f_1 : f32,
}
fn truncate_shader_output(io : ShaderIO) -> tint_symbol {
return tint_symbol(io.pos, io.f_1);
}
struct ShaderIO {
pos : vec4<f32>,
f_1 : f32,
f_3 : f32,
}
@vertex
fn f() -> tint_symbol {
var io : ShaderIO;
io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
io.f_1 = 1.0;
io.f_3 = (io.f_1 + 3.0);
return truncate_shader_output(io);
}
)";

    auto output = Run<TruncateInterstageVariables>(shader, inputs);
    EXPECT_EQ(expected, str(output));
}
// Only @location(3) is marked as consumed, so @location(1) is dropped. The generated output struct
// keeps the position builtin and f_3, and ShaderIO keeps all members without IO attributes.
TEST_F(TruncateInterstageVariablesTest, BasicVertexTrimLocationAtEnd) {
    // The consuming fragment stage reads only @location(3).
    TruncateInterstageVariables::Config config;
    config.interstage_locations[3] = true;

    Transform::DataMap inputs;
    inputs.Add<TruncateInterstageVariables::Config>(config);

    auto* shader = R"(
struct ShaderIO {
@builtin(position) pos: vec4<f32>,
@location(1) f_1: f32,
@location(3) f_3: f32,
}
@vertex
fn f() -> ShaderIO {
var io: ShaderIO;
io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
io.f_1 = 1.0;
io.f_3 = io.f_1 + 3.0;
return io;
}
)";
    auto* expected = R"(
struct tint_symbol {
@builtin(position)
pos : vec4<f32>,
@location(3)
f_3 : f32,
}
fn truncate_shader_output(io : ShaderIO) -> tint_symbol {
return tint_symbol(io.pos, io.f_3);
}
struct ShaderIO {
pos : vec4<f32>,
f_1 : f32,
f_3 : f32,
}
@vertex
fn f() -> tint_symbol {
var io : ShaderIO;
io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
io.f_1 = 1.0;
io.f_3 = (io.f_1 + 3.0);
return truncate_shader_output(io);
}
)";

    auto output = Run<TruncateInterstageVariables>(shader, inputs);
    EXPECT_EQ(expected, str(output));
}
// A default Config marks no interstage location, so every @location member is truncated. Only the
// position builtin remains in the generated output struct.
TEST_F(TruncateInterstageVariablesTest, TruncateAllLocations) {
    auto* shader = R"(
struct ShaderIO {
@builtin(position) pos: vec4<f32>,
@location(1) f_1: f32,
@location(3) f_3: f32,
}
@vertex
fn f() -> ShaderIO {
var io: ShaderIO;
io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
io.f_1 = 1.0;
io.f_3 = io.f_1 + 3.0;
return io;
}
)";
    auto* expected = R"(
struct tint_symbol {
@builtin(position)
pos : vec4<f32>,
}
fn truncate_shader_output(io : ShaderIO) -> tint_symbol {
return tint_symbol(io.pos);
}
struct ShaderIO {
pos : vec4<f32>,
f_1 : f32,
f_3 : f32,
}
@vertex
fn f() -> tint_symbol {
var io : ShaderIO;
io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
io.f_1 = 1.0;
io.f_3 = (io.f_1 + 3.0);
return truncate_shader_output(io);
}
)";

    // Intentionally leave every interstage location unmarked.
    TruncateInterstageVariables::Config config;
    Transform::DataMap inputs;
    inputs.Add<TruncateInterstageVariables::Config>(config);

    auto output = Run<TruncateInterstageVariables>(shader, inputs);
    EXPECT_EQ(expected, str(output));
}
// Test that the transform removes only the IO attributes and preserves other attributes (such as
// @align and @size) in the original ShaderIO struct.
TEST_F(TruncateInterstageVariablesTest, RemoveIOAttributes) {
    auto* src = R"(
struct ShaderIO {
@builtin(position) @invariant pos: vec4<f32>,
@location(1) f_1: f32,
@location(3) @align(16) f_3: f32,
@location(5) @interpolate(flat) @align(16) @size(16) f_5: u32,
}
@vertex
fn f() -> ShaderIO {
var io: ShaderIO;
io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
io.f_1 = 1.0;
io.f_3 = io.f_1 + 3.0;
io.f_5 = 1u;
return io;
}
)";
    // Expected result: the generated output struct carries every attribute of the kept members,
    // while ShaderIO keeps only the non-IO layout attributes (@align, @size).
    auto* expect = R"(
struct tint_symbol {
@builtin(position) @invariant
pos : vec4<f32>,
@location(3) @align(16)
f_3 : f32,
@location(5) @interpolate(flat) @align(16) @size(16)
f_5 : u32,
}
fn truncate_shader_output(io : ShaderIO) -> tint_symbol {
return tint_symbol(io.pos, io.f_3, io.f_5);
}
struct ShaderIO {
pos : vec4<f32>,
f_1 : f32,
@align(16)
f_3 : f32,
@align(16) @size(16)
f_5 : u32,
}
@vertex
fn f() -> tint_symbol {
var io : ShaderIO;
io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
io.f_1 = 1.0;
io.f_3 = (io.f_1 + 3.0);
io.f_5 = 1u;
return truncate_shader_output(io);
}
)";
    TruncateInterstageVariables::Config cfg;
    // @location(1) is intentionally left unmarked so that the transform runs.
    cfg.interstage_locations[3] = true;
    cfg.interstage_locations[5] = true;
    Transform::DataMap data;
    data.Add<TruncateInterstageVariables::Config>(cfg);
    auto got = Run<TruncateInterstageVariables>(src, data);
    EXPECT_EQ(expect, str(got));
}
// Two vertex entry points return the same ShaderIO struct. A single generated output struct and
// truncate_shader_output() helper are emitted and reused by both f1 and f2.
TEST_F(TruncateInterstageVariablesTest, MultipleEntryPointsSharingStruct) {
    // The consuming fragment stage reads only @location(3).
    TruncateInterstageVariables::Config config;
    config.interstage_locations[3] = true;

    Transform::DataMap inputs;
    inputs.Add<TruncateInterstageVariables::Config>(config);

    auto* shader = R"(
struct ShaderIO {
@builtin(position) pos: vec4<f32>,
@location(1) f_1: f32,
@location(3) f_3: f32,
@location(5) f_5: f32,
}
@vertex
fn f1() -> ShaderIO {
var io: ShaderIO;
io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
io.f_1 = 1.0;
io.f_3 = 1.0;
return io;
}
@vertex
fn f2() -> ShaderIO {
var io: ShaderIO;
io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
io.f_5 = 2.0;
return io;
}
)";
    auto* expected = R"(
struct tint_symbol {
@builtin(position)
pos : vec4<f32>,
@location(3)
f_3 : f32,
}
fn truncate_shader_output(io : ShaderIO) -> tint_symbol {
return tint_symbol(io.pos, io.f_3);
}
struct ShaderIO {
pos : vec4<f32>,
f_1 : f32,
f_3 : f32,
f_5 : f32,
}
@vertex
fn f1() -> tint_symbol {
var io : ShaderIO;
io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
io.f_1 = 1.0;
io.f_3 = 1.0;
return truncate_shader_output(io);
}
@vertex
fn f2() -> tint_symbol {
var io : ShaderIO;
io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
io.f_5 = 2.0;
return truncate_shader_output(io);
}
)";

    auto output = Run<TruncateInterstageVariables>(shader, inputs);
    EXPECT_EQ(expected, str(output));
}
// Two vertex entry points with distinct IO structs each get their own generated output struct and
// truncate helper; the second pair is suffixed with _1. ShaderIO2 has no marked location, so only
// its position builtin survives.
TEST_F(TruncateInterstageVariablesTest, MultipleEntryPoints) {
    // The consuming fragment stage reads only @location(3).
    TruncateInterstageVariables::Config config;
    config.interstage_locations[3] = true;

    Transform::DataMap inputs;
    inputs.Add<TruncateInterstageVariables::Config>(config);

    auto* shader = R"(
struct ShaderIO1 {
@builtin(position) pos: vec4<f32>,
@location(1) f_1: f32,
@location(3) f_3: f32,
@location(5) f_5: f32,
}
@vertex
fn f1() -> ShaderIO1 {
var io: ShaderIO1;
io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
io.f_1 = 1.0;
io.f_3 = 1.0;
return io;
}
struct ShaderIO2 {
@builtin(position) pos: vec4<f32>,
@location(2) f_2: f32,
}
@vertex
fn f2() -> ShaderIO2 {
var io: ShaderIO2;
io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
io.f_2 = 2.0;
return io;
}
)";
    auto* expected = R"(
struct tint_symbol {
@builtin(position)
pos : vec4<f32>,
@location(3)
f_3 : f32,
}
fn truncate_shader_output(io : ShaderIO1) -> tint_symbol {
return tint_symbol(io.pos, io.f_3);
}
struct tint_symbol_1 {
@builtin(position)
pos : vec4<f32>,
}
fn truncate_shader_output_1(io : ShaderIO2) -> tint_symbol_1 {
return tint_symbol_1(io.pos);
}
struct ShaderIO1 {
pos : vec4<f32>,
f_1 : f32,
f_3 : f32,
f_5 : f32,
}
@vertex
fn f1() -> tint_symbol {
var io : ShaderIO1;
io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
io.f_1 = 1.0;
io.f_3 = 1.0;
return truncate_shader_output(io);
}
struct ShaderIO2 {
pos : vec4<f32>,
f_2 : f32,
}
@vertex
fn f2() -> tint_symbol_1 {
var io : ShaderIO2;
io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
io.f_2 = 2.0;
return truncate_shader_output_1(io);
}
)";

    auto output = Run<TruncateInterstageVariables>(shader, inputs);
    EXPECT_EQ(expected, str(output));
}
// Every return statement in the entry point is rewritten to go through truncate_shader_output().
// This includes the early return inside the if-block.
TEST_F(TruncateInterstageVariablesTest, MultipleReturnStatements) {
    // The consuming fragment stage reads only @location(3).
    TruncateInterstageVariables::Config config;
    config.interstage_locations[3] = true;

    Transform::DataMap inputs;
    inputs.Add<TruncateInterstageVariables::Config>(config);

    auto* shader = R"(
struct ShaderIO {
@builtin(position) pos: vec4<f32>,
@location(1) f_1: f32,
@location(3) f_3: f32,
}
@vertex
fn f(@builtin(vertex_index) vid: u32) -> ShaderIO {
var io: ShaderIO;
if (vid > 10u) {
io.f_1 = 2.0;
return io;
}
io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
io.f_1 = 1.0;
io.f_3 = io.f_1 + 3.0;
return io;
}
)";
    auto* expected = R"(
struct tint_symbol {
@builtin(position)
pos : vec4<f32>,
@location(3)
f_3 : f32,
}
fn truncate_shader_output(io : ShaderIO) -> tint_symbol {
return tint_symbol(io.pos, io.f_3);
}
struct ShaderIO {
pos : vec4<f32>,
f_1 : f32,
f_3 : f32,
}
@vertex
fn f(@builtin(vertex_index) vid : u32) -> tint_symbol {
var io : ShaderIO;
if ((vid > 10u)) {
io.f_1 = 2.0;
return truncate_shader_output(io);
}
io.pos = vec4<f32>(1.0, 1.0, 1.0, 1.0);
io.f_1 = 1.0;
io.f_3 = (io.f_1 + 3.0);
return truncate_shader_output(io);
}
)";

    auto output = Run<TruncateInterstageVariables>(shader, inputs);
    EXPECT_EQ(expected, str(output));
}
} // namespace
} // namespace tint::ast::transform