blob: e5fa158eb596c42da06ce012c9fd8967d0673d4f [file] [log] [blame]
package androidx.webgpu
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.filters.SmallTest
import androidx.webgpu.helper.ValidationException
import androidx.webgpu.helper.WebGpu
import androidx.webgpu.helper.createWebGpu
import junit.framework.TestCase.assertEquals
import kotlinx.coroutines.runBlocking
import org.junit.After
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.Assert.assertThrows
import org.junit.Assume
import org.junit.Before
/**
 * Instrumented tests for [GPUDevice] resource creation and WebGPU error reporting.
 *
 * Each test gets a fresh [WebGpu] instance from [createWebGpu]. Validation failures are observed
 * either through an explicit validation error scope (see [assertValidationError]) or, when no
 * scope is active, as a thrown [ValidationException].
 */
@RunWith(AndroidJUnit4::class)
@SmallTest
class DeviceTest {
    private lateinit var device: GPUDevice
    private lateinit var webGpu: WebGpu

    @Before
    fun setup() {
        webGpu = runBlocking { createWebGpu() }
        device = webGpu.device
    }

    @After
    fun teardown() {
        // If setup() failed part-way, these properties may be unassigned; touching them would
        // add an UninitializedPropertyAccessException on top of the real failure.
        if (::device.isInitialized) {
            // Best effort: a failure to destroy the device must not prevent close() below.
            runCatching { device.destroy() }
        }
        if (::webGpu.isInitialized) {
            webGpu.close()
        }
    }

    /**
     * Pushes an [ErrorFilter.Validation] error scope, runs [action], pops the scope, and asserts
     * that the pop succeeded and that the captured error is a validation error.
     *
     * Prerequisite objects that are expected to be valid should be created before calling this,
     * so that only [action] executes inside the scope.
     */
    private fun assertValidationError(action: () -> Unit) {
        device.pushErrorScope(ErrorFilter.Validation)
        action()
        val error = runBlocking { device.popErrorScope() }
        // Check the status first: if the pop itself did not succeed, a mismatch on the error
        // type would be a misleading failure message.
        assertEquals(PopErrorScopeStatus.Success, error.status)
        assertEquals(ErrorType.Validation, error.type)
    }

    @Test
    @SmallTest
    fun testHasFeature() {
        runBlocking {
            // Checks that hasFeature() can be called without throwing. The result is fed to an
            // assumption, so on devices without TimestampQuery the test is reported as skipped
            // rather than failed.
            Assume.assumeTrue(device.hasFeature(FeatureName.TimestampQuery))
        }
    }

    @Test
    @SmallTest
    fun testErrorScope() {
        // Intentionally create an invalid buffer to trigger a validation error:
        // when mappedAtCreation is true, the buffer size must be a multiple of 4.
        assertValidationError {
            device.createBuffer(
                BufferDescriptor(
                    size = 1,
                    usage = BufferUsage.Vertex,
                    mappedAtCreation = true
                )
            )
        }
    }

    @Test
    @SmallTest
    fun testCreateBuffer() {
        val buffer = device.createBuffer(
            BufferDescriptor(size = 4, usage = BufferUsage.Vertex)
        )
        assertEquals(BufferUsage.Vertex, buffer.usage)
    }

    @Test
    @SmallTest
    fun testCreateTexture() {
        val texture = device.createTexture(
            TextureDescriptor(
                size = Extent3D(1, 1, 1),
                format = TextureFormat.RGBA8Unorm,
                usage = TextureUsage.TextureBinding
            )
        )
        assertEquals(TextureUsage.TextureBinding, texture.usage)
    }

    @Test
    @SmallTest
    fun testCreateComputePipeline_withInvalidEntryPoint_throwsException() {
        val shaderModule = device.createShaderModule(
            ShaderModuleDescriptor(
                shaderSourceWGSL = ShaderSourceWGSL(
                    code = "@compute @workgroup_size(1) fn main() {}"
                )
            )
        )
        // No error scope is active, so the validation failure surfaces as an exception.
        assertThrows(ValidationException::class.java) {
            device.createComputePipeline(
                ComputePipelineDescriptor(
                    compute = ComputeState(
                        module = shaderModule,
                        entryPoint = "non_existent_entry_point"
                    )
                )
            )
        }
    }

    @Test
    @SmallTest
    fun testCreateShaderModule_withInvalidShader_throwsException() {
        // This shader has a syntax error ("fu" instead of "fn").
        val badShaderCode = "@compute @workgroup_size(1) fu main() {}"
        // Creating the shader module itself should fail.
        assertThrows(ValidationException::class.java) {
            device.createShaderModule(
                ShaderModuleDescriptor(shaderSourceWGSL = ShaderSourceWGSL(code = badShaderCode))
            )
        }
    }

    /**
     * Verifies that createRenderPipeline fails validation if the entry point is incorrect.
     */
    @Test
    @SmallTest
    fun testCreateRenderPipeline_withInvalidEntryPoint_failsValidation() {
        val shaderModule = device.createShaderModule(
            ShaderModuleDescriptor(
                shaderSourceWGSL = ShaderSourceWGSL(
                    code = "@vertex fn main() -> @builtin(position) vec4<f32> { return vec4<f32>(0.0); }"
                )
            )
        )
        assertValidationError {
            device.createRenderPipeline(
                RenderPipelineDescriptor(
                    vertex = VertexState(
                        module = shaderModule,
                        entryPoint = "non_existent_entry_point" // Invalid
                    )
                )
            )
        }
    }

    /**
     * Verifies that createBindGroupLayout fails validation when two entries share a binding index.
     */
    @Test
    @SmallTest
    fun testCreateBindGroupLayout_withDuplicateBindings_failsValidation() {
        assertValidationError {
            device.createBindGroupLayout(
                BindGroupLayoutDescriptor(
                    entries = arrayOf(
                        BindGroupLayoutEntry(
                            binding = 0, // Duplicate
                            visibility = ShaderStage.Fragment,
                            buffer = BufferBindingLayout(type = BufferBindingType.Storage)
                        ),
                        BindGroupLayoutEntry(
                            binding = 0, // Duplicate
                            visibility = ShaderStage.Fragment,
                            buffer = BufferBindingLayout(type = BufferBindingType.Storage)
                        )
                    )
                )
            )
        }
    }

    /**
     * Verifies that createBindGroup fails validation if the buffer's usage is incorrect.
     */
    @Test
    @SmallTest
    fun testCreateBindGroup_withMismatchedBufferUsage_failsValidation() {
        val layout = device.createBindGroupLayout(
            BindGroupLayoutDescriptor(
                entries = arrayOf(
                    BindGroupLayoutEntry(
                        binding = 0,
                        visibility = ShaderStage.Compute,
                        buffer = BufferBindingLayout(type = BufferBindingType.Uniform)
                    )
                )
            )
        )
        // Create a buffer WITHOUT the required `Uniform` usage.
        val buffer = device.createBuffer(
            BufferDescriptor(size = 16, usage = BufferUsage.CopySrc) // Invalid usage
        )
        assertValidationError {
            device.createBindGroup(
                BindGroupDescriptor(
                    layout = layout,
                    entries = arrayOf(BindGroupEntry(binding = 0, buffer = buffer))
                )
            )
        }
    }

    @Test
    @SmallTest
    fun testCreateQuerySet_withInvalidCount_failsValidation() {
        assertValidationError {
            device.createQuerySet(
                QuerySetDescriptor(
                    type = QueryType.Occlusion,
                    // Invalid: a negative query count.
                    // NOTE(review): WebGPU only caps count (<= 4096); this presumably relies on -1
                    // reaching the native layer as a huge unsigned value — confirm.
                    count = -1
                )
            )
        }
    }

    @Test
    @SmallTest
    fun validationError_withoutActiveErrorScope_throwsValidationException() {
        val invalidDescriptor = QuerySetDescriptor(
            type = QueryType.Occlusion,
            count = -1 // Invalid parameter
        )
        // With no error scope pushed, the validation error is reported as an exception.
        assertThrows(ValidationException::class.java) {
            device.createQuerySet(invalidDescriptor)
        }
    }

    /**
     * Verifies that createSampler fails validation with an invalid descriptor.
     */
    @Test
    @SmallTest
    fun testCreateSampler_withInvalidLodClamp_failsValidation() {
        val invalidDescriptor = SamplerDescriptor(
            lodMinClamp = 10.0f,
            lodMaxClamp = 1.0f // Invalid: min cannot be greater than max.
        )
        assertValidationError { device.createSampler(invalidDescriptor) }
    }
}