blob: b11a9abaff2a61917ea13cfd45eac7a4bc168c5e [file] [log] [blame]
package androidx.webgpu
import androidx.test.filters.SmallTest
import androidx.webgpu.helper.WebGpu
import androidx.webgpu.helper.createWebGpu
import kotlinx.coroutines.runBlocking
import org.junit.Test
import junit.framework.TestCase.assertEquals
import org.junit.After
import org.junit.Assert.assertThrows
import org.junit.Before
/**
 * Tests for the compilation info returned by shader modules created on a [GPUDevice].
 *
 * Each test compiles a small ASCII-only WGSL snippet and inspects the messages of the
 * resulting [CompilationInfo].
 */
@Suppress("UNUSED_VARIABLE")
@SmallTest
class ShaderModuleTest {
    private lateinit var webGpu: WebGpu
    private lateinit var device: GPUDevice

    /** WGSL source whose second line calls the undeclared function `unknown`. */
    private val invalidShader = """
        @vertex fn main() -> @builtin(position) vec4<f32> {
        return unknown(0.0, 0.0, 0.0, 1.0);
        }
    """.trimIndent()

    /** Creates the WebGPU helper and caches its device for the test about to run. */
    @Before
    fun setup() {
        runBlocking {
            webGpu = createWebGpu()
            device = webGpu.device
        }
    }

    /**
     * Releases whatever [setup] managed to create.
     *
     * JUnit still runs `@After` methods when `@Before` throws, so both properties are checked
     * before use; otherwise an `UninitializedPropertyAccessException` raised here would be
     * reported on top of the real setup failure.
     */
    @After
    fun teardown() {
        if (::device.isInitialized) {
            // Best effort: a failure to destroy the device must not prevent closing webGpu.
            runCatching { device.destroy() }
        }
        if (::webGpu.isInitialized) {
            webGpu.close()
        }
    }

    /**
     * Creates a shader module from [code] and returns its compilation info.
     *
     * The info request should still succeed even if compilation fails.
     */
    private suspend fun getCompilationInfo(code: String): CompilationInfo {
        val shaderModule =
            device.createShaderModule(
                ShaderModuleDescriptor(shaderSourceWGSL = ShaderSourceWGSL(code))
            )
        return shaderModule.getCompilationInfo()
    }

    /**
     * Compiles [invalidShader] inside a validation error scope and returns its compilation info.
     *
     * The scope is popped immediately after compilation, before the caller inspects the
     * returned info, so a failing assertion in a test cannot leave the scope pushed.
     */
    private fun compileInvalidShaderExpectingValidationError(): CompilationInfo {
        device.pushErrorScope(ErrorFilter.Validation)
        val info = runBlocking { getCompilationInfo(invalidShader) }
        assertThrows(
            "The operation should result in a validation error",
            ValidationException::class.java
        ) {
            runBlocking { device.popErrorScope() }
        }
        return info
    }

    /** Number of messages in this info whose type is [CompilationMessageType.Error]. */
    private fun CompilationInfo.errorCount() =
        messages.count { it.type == CompilationMessageType.Error }

    /**
     * First message of type [CompilationMessageType.Error].
     *
     * Fails with a descriptive [AssertionError] instead of a bare `NoSuchElementException`
     * when no error message is present.
     */
    private fun CompilationInfo.firstError() =
        messages.firstOrNull { it.type == CompilationMessageType.Error }
            ?: throw AssertionError("Expected at least one compilation message of type Error.")

    /** A well-formed ASCII shader must not produce any error messages. */
    @Test
    fun getCompilationInfoReturnsNoErrorsForValidAsciiShader() {
        val code = """
            @vertex fn main() -> @builtin(position) vec4<f32> {
            return vec4<f32>(0.0, 0.0, 0.0, 1.0);
            }
        """.trimIndent()

        val info = runBlocking { getCompilationInfo(code) }

        assertEquals(0, info.errorCount())
    }

    /**
     * Verifies that an invalid shader correctly produces a validation error
     * and a compilation error message.
     */
    @Test
    fun invalidShader_producesACompilationError() {
        val info = compileInvalidShaderExpectingValidationError()

        assertEquals(1, info.errorCount())
    }

    /**
     * Verifies that the primary compilation error for the invalid shader is reported
     * on the correct line number.
     */
    @Test
    fun invalidShader_reportsCorrectLineNumber() {
        val errorMessage = compileInvalidShaderExpectingValidationError().firstError()

        // The undeclared call sits on the second line of invalidShader (line numbers are 1-based).
        val expectedErrorLine = 2L
        assertEquals(
            "The error should be reported on line $expectedErrorLine.",
            expectedErrorLine,
            errorMessage.lineNum
        )
    }

    /**
     * Verifies that the absolute character offset reported in the error message is
     * mathematically consistent with its reported line number and line position (column).
     *
     * The expected offset is rebuilt by counting one unit per `Char` plus one per line break,
     * which is only meaningful because [invalidShader] is ASCII-only.
     */
    @Test
    fun invalidShader_reportsConsistentOffsetAndLinePosition() {
        val errorMessage = compileInvalidShaderExpectingValidationError().firstError()

        // Pre-computation based on the error message's reported location.
        val lines = invalidShader.split('\n')
        val lineIndex = (errorMessage.lineNum - 1).toInt() // 1-based line number to 0-based index
        val positionOffset = errorMessage.linePos - 1 // 1-based column to 0-based offset

        // Length of every preceding line (+1 for its '\n') plus the column within the error line.
        val calculatedOffset = lines.take(lineIndex).sumOf { it.length + 1 } + positionOffset

        assertEquals(
            "Calculated offset from line/pos should match the message's absolute offset.",
            errorMessage.offset,
            calculatedOffset
        )
    }
}