Revert "Kotlin: Async adapter returns only the actual return value."
This reverts commit dfe3624b7ef59534c1b0727940396b51d2aee802.
Reason for revert: Creates copybara circular dependencies; post-review code change caused test breakage (./gradlew connectedAndroidTest)
Original change's description:
> Kotlin: Async adapter returns only the actual return value.
>
> A 'Return' wrapper class is no longer created or used.
>
> The 'status' and 'message' values are used to raise exceptions when relevant,
> but not passed to the client.
>
> Bug: 456671498
>
> Change-Id: If327df443606b9c277a8f061c8c20e3f7e5b6993
> Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/270935
> Commit-Queue: Jim Blackler <jimblackler@google.com>
> Reviewed-by: Tarun Saini <sainitarun@google.com>
TBR=dawn-scoped@luci-project-accounts.iam.gserviceaccount.com,jimblackler@google.com,sainitarun@google.com
No-Presubmit: true
No-Tree-Checks: true
No-Try: true
Change-Id: Ic1f4dc048e46c3c3e6603d1aec8990ee4168c36a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/271574
Auto-Submit: Jim Blackler <jimblackler@google.com>
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Jim Blackler <jimblackler@google.com>
Reviewed-by: Jim Blackler <jimblackler@google.com>
diff --git a/generator/templates/art/api_kotlin_async_helpers.kt b/generator/templates/art/api_kotlin_async_helpers.kt
index 5f5c8b8..226b44b 100644
--- a/generator/templates/art/api_kotlin_async_helpers.kt
+++ b/generator/templates/art/api_kotlin_async_helpers.kt
@@ -28,19 +28,32 @@
{% from 'art/api_kotlin_types.kt' import kotlin_annotation, kotlin_declaration, kotlin_definition, check_if_doc_present, generate_kdoc with context %}
+{% set all_callback_info = kdocs.callbacks %}
{% set all_objects = kdocs.objects %}
{% macro async_wrapper(obj, method, callback_arg) %}
//* Generate KDocs
- {% set ns = namespace(status_arg = none, message_arg = none, payload_arg = none) %}
- {% for arg in kotlin_record_members(callback_arg.type.arguments) %}
- {% if arg.name.get() == 'status' %}
- {% set ns.status_arg = arg %}
- {% elif arg.name.get() == 'message' %}
- {% set ns.message_arg = arg %}
- {% else %}
- {% set ns.payload_arg = arg %}
- {% endif %}
- {% endfor %}
+ {% set callback_doc_info = all_callback_info.get(callback_arg.type.name.get()) %}
+ {% set callback_doc = callback_doc_info.doc if callback_doc_info else "" %}
+ {% set callback_args_doc = callback_doc_info.args if callback_doc_info else {} %}
+ {% set callback_args = kotlin_record_members(callback_arg.type.arguments) | list %}
+ {% if check_if_doc_present(callback_doc, "", callback_args_doc, callback_args) == 'True' %}
+ {{- generate_kdoc(callback_doc, "", callback_args_doc, callback_args, line_wrap_prefix = "\n * ") }}
+ {%- endif %}
+
+ {% set return_name = callback_arg.type.name.chunks[:-1] | map('title') | join + 'Return' %}
+ {% set result_args = kotlin_record_members(callback_arg.type.arguments) | list %}
+ //* We make a return class to receive the callback's (possibly multiple) return values.
+ public class {{ return_name }}(
+ {% for arg in result_args %}
+ {{ kotlin_annotation(arg) }} public val {{ as_varName(arg.name) }}: {{ kotlin_declaration(arg) }}{{ ',' if not loop.last }}
+ {% endfor %}) {
+ //* Required for destructuring declarations. These come for free in a 'data' class but
+ //* we don't make it a data class because that can cause binary compatibility issues.
+ {% for arg in result_args %}
+ public operator fun component{{ loop.index }}() : {{ kotlin_declaration(arg) }} =
+ {{- as_varName(arg.name) }}
+ {% endfor %}
+ }
//* Generating KDocs
{% set object_info = all_objects.get(obj.name.get()) %}
@@ -63,44 +76,27 @@
{%- endif %}
//* The wrapped method has executor and callback function stripped out (the wrapper supplies
//* those so the client doesn't have to).
- {{ kotlin_annotation(ns.payload_arg) if ns.payload_arg }}
public suspend fun {{ method.name.camelCase() }}(
{%- for arg in kotlin_record_members(method.arguments) if not (
arg.type.category == 'callback function' or
(arg.type.category == 'kotlin type' and arg.type.name.get() == 'java.util.concurrent.Executor')
) %}
- {{ kotlin_annotation(arg) }} {{ as_varName(arg.name) }}: {{ kotlin_definition(arg) }},
- {%- endfor %}): {{ kotlin_declaration(ns.payload_arg, true) if ns.payload_arg else 'Unit' -}}
- = suspendCancellableCoroutine {
- {{ method.name.camelCase() }}(
+ {{ kotlin_annotation(arg) }} {{ as_varName(arg.name) }}: {{ kotlin_definition(arg) }},
+ {%- endfor %}): {{ return_name }} = suspendCancellableCoroutine {
+ {{ method.name.camelCase() }}(
{%- for arg in kotlin_record_members(method.arguments) %}
{%- if arg.type.category == 'kotlin type' and arg.type.name.get() == 'java.util.concurrent.Executor' -%}
Executor(Runnable::run),
{%- elif arg.name.get() == callback_arg.name.get() %}{
{%- for arg in kotlin_record_members(callback_arg.type.arguments) %}
{{- as_varName(arg.name) }},
- {%- endfor %} ->
- if (!it.isActive) {
- return@{{ method.name.camelCase() }}
- }
- if ({{ as_varName(ns.status_arg.name) }} != {{ ns.status_arg.name.CamelCase() }}.Success
- {%- if ns.message_arg %}|| {{ as_varName(ns.message_arg.name) }}.isNotEmpty(){% endif %}) {
- throw DawnException(
- {%- if ns.status_arg %}status = {{ as_varName(ns.status_arg.name) }},{% endif %}
- {%- if ns.message_arg %}reason = {{ as_varName(ns.message_arg.name) }}{% endif -%}
- )
- }
- it.resume(
- {%- if ns.payload_arg %} {{ as_varName(ns.payload_arg.name) }}
- {%- if ns.payload_arg.optional %} ?: throw DawnException(
- {%- if ns.status_arg %}status = {{ as_varName(ns.status_arg.name) }}, {% endif %}
- reason = "Null value returned")
- {%- endif %}
- {%- else -%}
- Unit
- {%- endif -%}
- )
- }
+ {%- endfor %} -> if (it.isActive) {
+ it.resume({{ return_name }}(
+ //* We make an instance of the callback parameters -> return type wrapper.
+ {%- for arg in result_args %}
+ {{- as_varName(arg.name) }},
+ {%- endfor %}))
+ }}
{%- else -%}
{{- as_varName(arg.name) }},
{%- endif %}
diff --git a/generator/templates/art/api_kotlin_object.kt b/generator/templates/art/api_kotlin_object.kt
index dbc8475..00d7519 100644
--- a/generator/templates/art/api_kotlin_object.kt
+++ b/generator/templates/art/api_kotlin_object.kt
@@ -26,7 +26,6 @@
//* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package {{ kotlin_package }}
-import {{ kotlin_package }}.helper.DawnException
import dalvik.annotation.optimization.FastNative
import java.nio.ByteBuffer
import java.util.concurrent.Executor
diff --git a/generator/templates/art/api_kotlin_types.kt b/generator/templates/art/api_kotlin_types.kt
index d93d9f3..11f2ef5 100644
--- a/generator/templates/art/api_kotlin_types.kt
+++ b/generator/templates/art/api_kotlin_types.kt
@@ -25,9 +25,9 @@
//* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
//* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-{%- macro kotlin_declaration(arg, strip_optional = False) -%}
+{%- macro kotlin_declaration(arg) -%}
{%- set type = arg.type %}
- {%- set optional = arg.optional and not strip_optional %}
+ {%- set optional = arg.optional %}
{%- set default_value = arg.default_value %}
{%- if arg == None -%}
Unit
diff --git a/tools/android/webgpu/src/androidTest/java/androidx/webgpu/AdapterTest.kt b/tools/android/webgpu/src/androidTest/java/androidx/webgpu/AdapterTest.kt
index f08e3d3..7376624 100644
--- a/tools/android/webgpu/src/androidTest/java/androidx/webgpu/AdapterTest.kt
+++ b/tools/android/webgpu/src/androidTest/java/androidx/webgpu/AdapterTest.kt
@@ -17,8 +17,11 @@
val instance = createInstance()
runBlocking {
- val adapter = instance.requestAdapter()
-
+ val result = instance.requestAdapter()
+ val adapter = result.adapter
+ check(result.status == RequestAdapterStatus.Success && adapter != null) {
+ result.message ?: "Error requesting the adapter"
+ }
val adapterInfo = adapter.getInfo()
assertEquals("The backend type should be Vulkan",
diff --git a/tools/android/webgpu/src/androidTest/java/androidx/webgpu/AsyncHelperTest.kt b/tools/android/webgpu/src/androidTest/java/androidx/webgpu/AsyncHelperTest.kt
index 71471c4..351d4d0 100644
--- a/tools/android/webgpu/src/androidTest/java/androidx/webgpu/AsyncHelperTest.kt
+++ b/tools/android/webgpu/src/androidTest/java/androidx/webgpu/AsyncHelperTest.kt
@@ -2,13 +2,11 @@
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.filters.SmallTest
-import androidx.webgpu.helper.DawnException
import androidx.webgpu.helper.WebGpu
import androidx.webgpu.helper.createWebGpu
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.junit.Assert.assertEquals
-import org.junit.Assert.assertThrows
import org.junit.Assume.assumeFalse
import org.junit.Before
import org.junit.Test
@@ -22,15 +20,6 @@
private lateinit var webGpu: WebGpu
private lateinit var device: GPUDevice
- private val BASIC_SHADER = """
- @vertex fn vertexMain(@builtin(vertex_index) i : u32) ->
- @builtin(position) vec4f {
- return vec4f();
- }
- @fragment fn fragmentMain() -> @location(0) vec4f {
- return vec4f();
- } """
-
@Before
fun setup() {
runBlocking {
@@ -47,36 +36,43 @@
ShaderModuleDescriptor(shaderSourceWGSL = ShaderSourceWGSL(""))
)
- val exception = assertThrows(DawnException::class.java) {
- runBlocking {
- /* Call an asynchronous method, converted from a callback pattern by a helper. */
- device.createRenderPipelineAsync(
- RenderPipelineDescriptor(vertex = VertexState(module = shaderModule))
- )
- }
- }
-
- assertEquals(
- """Create render pipeline (async) should fail when no shader entry point exists.
- The result was: ${exception.status}""",
- CreatePipelineAsyncStatus.ValidationError,
- exception.status
+ /* Call an asynchronous method, converted from a callback pattern by a helper. */
+ val result = device.createRenderPipelineAsync(
+ RenderPipelineDescriptor(vertex = VertexState(module = shaderModule))
)
+
+ assert(result.status == CreatePipelineAsyncStatus.ValidationError) {
+ """Create render pipeline (async) should fail when no shader entry point exists.
+ The result was: ${result.status}"""
+ }
}
}
@Test
fun asyncMethodTestValidationPasses() {
runBlocking {
- /* Set up a valid shader module and descriptor */
+ /* Set up a shader module to support the async call. */
val shaderModule = device.createShaderModule(
- ShaderModuleDescriptor(shaderSourceWGSL = ShaderSourceWGSL(BASIC_SHADER))
+ ShaderModuleDescriptor(
+ shaderSourceWGSL = ShaderSourceWGSL(
+ """
+@vertex fn vertexMain(@builtin(vertex_index) i : u32) ->
+@builtin(position) vec4f {
+ return vec4f();
+}
+@fragment fn fragmentMain() -> @location(0) vec4f {
+ return vec4f();
+}
+ """
+ )
+ )
)
/* Call an asynchronous method, converted from a callback pattern by a helper. */
- device.createRenderPipelineAsync(
+ val result = device.createRenderPipelineAsync(
RenderPipelineDescriptor(
- vertex = VertexState(module = shaderModule), fragment = FragmentState(
+ vertex = VertexState(module = shaderModule),
+ fragment = FragmentState(
module = shaderModule,
targets = arrayOf(ColorTargetState(format = TextureFormat.RGBA8Unorm))
)
@@ -84,6 +80,8 @@
)
/* Create render pipeline (async) should pass with a simple shader.. */
+ assertEquals(result.status, CreatePipelineAsyncStatus.Success)
+
}
}
@@ -92,18 +90,13 @@
runBlocking {
val shaderModule = device.createShaderModule(
- ShaderModuleDescriptor(shaderSourceWGSL = ShaderSourceWGSL(BASIC_SHADER))
+ ShaderModuleDescriptor(shaderSourceWGSL = ShaderSourceWGSL(""))
)
/* Launch the function in a new coroutine, giving us a job handle we can cancel. */
val job = launch {
- var unused = device.createRenderPipelineAsync(
- RenderPipelineDescriptor(vertex = VertexState(module = shaderModule),
- fragment = FragmentState(
- module = shaderModule,
- targets = arrayOf(ColorTargetState(format = TextureFormat.RGBA8Unorm))
- )
- )
+ val unused = device.createRenderPipelineAsync(
+ RenderPipelineDescriptor(vertex = VertexState(module = shaderModule))
)
hasReturned.set(true)
}
diff --git a/tools/android/webgpu/src/androidTest/java/androidx/webgpu/BufferTest.kt b/tools/android/webgpu/src/androidTest/java/androidx/webgpu/BufferTest.kt
index 5b83455..be67b28 100644
--- a/tools/android/webgpu/src/androidTest/java/androidx/webgpu/BufferTest.kt
+++ b/tools/android/webgpu/src/androidTest/java/androidx/webgpu/BufferTest.kt
@@ -153,8 +153,12 @@
)
val gpuCommand = commandEncoder.finish()
queue.submit(arrayOf(gpuCommand))
- runBlocking { queue.onSubmittedWorkDone() }
- runBlocking { gpuReadBuffer.mapAsync(mode = MapMode.Read, size = bufferSize, offset = 0) }
+ val result = runBlocking { queue.onSubmittedWorkDone() }
+ assertEquals(QueueWorkDoneStatus.Success, result.status)
+
+ val mapStatus =
+ runBlocking { gpuReadBuffer.mapAsync(mode = MapMode.Read, size = bufferSize, offset = 0) }
+ assertEquals(MapAsyncStatus.Success, mapStatus.status)
val arrayBuffer = gpuReadBuffer.getConstMappedRange(size = bufferSize).asFloatBuffer()
gpuReadBuffer.unmap()
@@ -186,7 +190,8 @@
assertEquals(BufferMapState.Unmapped, buffer.mapState)
- runBlocking { buffer.mapAsync(MapMode.Read, 0, bufferSize) }
+ val mapResult = runBlocking { buffer.mapAsync(MapMode.Read, 0, bufferSize) }
+ assertEquals(MapAsyncStatus.Success, mapResult.status)
assertEquals(BufferMapState.Mapped, buffer.mapState)
diff --git a/tools/android/webgpu/src/androidTest/java/androidx/webgpu/CommandEncoderTest.kt b/tools/android/webgpu/src/androidTest/java/androidx/webgpu/CommandEncoderTest.kt
index d419613..6ead622 100644
--- a/tools/android/webgpu/src/androidTest/java/androidx/webgpu/CommandEncoderTest.kt
+++ b/tools/android/webgpu/src/androidTest/java/androidx/webgpu/CommandEncoderTest.kt
@@ -9,7 +9,6 @@
import kotlinx.coroutines.runBlocking
import org.junit.After
import org.junit.Assert.assertArrayEquals
-import org.junit.Assert.assertEquals
import org.junit.Assert.assertThrows
import org.junit.Before
import org.junit.Test
@@ -69,9 +68,10 @@
encoder.popDebugGroup()
val error = runBlocking { device.popErrorScope() }
- TestCase.assertEquals(
+ TestCase.assertTrue(
"Expected no error for balanced push/pop debug group",
- ErrorType.NoError, error
+ error.status == PopErrorScopeStatus.Success &&
+ error.type == ErrorType.NoError
)
val unusedCommandBuffer = encoder.finish()
}
@@ -95,9 +95,9 @@
val unusedCommandBuffer = encoder.finish()
val error = runBlocking { device.popErrorScope() }
- TestCase.assertEquals(
+ TestCase.assertTrue(
"Expected a validation error on .finish() due to an earlier nested pass attempt",
- ErrorType.Validation, error
+ error.type == ErrorType.Validation,
)
}
@@ -198,7 +198,8 @@
queue.submit(arrayOf(commandBuffer))
- runBlocking { readbackBuffer.mapAsync(MapMode.Read, 0, bufferSize) }
+ val mapResult = runBlocking { readbackBuffer.mapAsync(MapMode.Read, 0, bufferSize) }
+ TestCase.assertEquals(MapAsyncStatus.Success, mapResult.status)
val mappedData: ByteBuffer = readbackBuffer.getConstMappedRange(0, bufferSize)
@@ -243,9 +244,10 @@
encoder.clearBuffer(buffer, 0, 16)
val error = runBlocking { device.popErrorScope() }
- assertEquals(
+ TestCase.assertTrue(
"Expected clearBuffer to succeed",
- ErrorType.NoError, error
+ error.status == PopErrorScopeStatus.Success &&
+ error.type == ErrorType.NoError
)
val unusedCommandBuffer = encoder.finish()
}
@@ -278,9 +280,10 @@
encoder.resolveQuerySet(querySet, 0, 1, destination, 0)
val error = runBlocking { device.popErrorScope() }
- assertEquals(
+ TestCase.assertTrue(
"Expected resolveQuerySet to succeed",
- ErrorType.NoError, error
+ error.status == PopErrorScopeStatus.Success &&
+ error.type == ErrorType.NoError
)
val unusedCommandBuffer = encoder.finish()
}
@@ -300,9 +303,10 @@
encoder.insertDebugMarker("MyDebugMarker")
val error = runBlocking { device.popErrorScope() }
- assertEquals(
+ TestCase.assertTrue(
"Expected insertDebugMarker to succeed",
- ErrorType.NoError, error
+ error.status == PopErrorScopeStatus.Success &&
+ error.type == ErrorType.NoError
)
val unusedCommandBuffer = encoder.finish()
}
diff --git a/tools/android/webgpu/src/androidTest/java/androidx/webgpu/ComputePassEncoderTest.kt b/tools/android/webgpu/src/androidTest/java/androidx/webgpu/ComputePassEncoderTest.kt
index 7643c1c..a617dc2 100644
--- a/tools/android/webgpu/src/androidTest/java/androidx/webgpu/ComputePassEncoderTest.kt
+++ b/tools/android/webgpu/src/androidTest/java/androidx/webgpu/ComputePassEncoderTest.kt
@@ -75,8 +75,8 @@
passEncoder.end()
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.NoError, error)
+ val errorScope = runBlocking { device.popErrorScope() }
+ assertEquals(ErrorType.NoError, errorScope.type)
}
/**
@@ -92,9 +92,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
}
/**
@@ -111,9 +111,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.NoError, error)
+ assertEquals(ErrorType.NoError, errorScope.type)
}
/**
@@ -132,9 +132,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.NoError, error)
+ assertEquals(ErrorType.NoError, errorScope.type)
}
/**
@@ -158,9 +158,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
invalidBuffer.destroy()
}
@@ -187,9 +187,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.NoError, error)
+ assertEquals(ErrorType.NoError, errorScope.type)
validBuffer.destroy()
}
@@ -208,9 +208,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
}
/**
@@ -227,8 +227,8 @@
device.pushErrorScope(ErrorFilter.Validation)
passEncoder.end() // Second call (invalid).
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
}
}
\ No newline at end of file
diff --git a/tools/android/webgpu/src/androidTest/java/androidx/webgpu/DeviceTest.kt b/tools/android/webgpu/src/androidTest/java/androidx/webgpu/DeviceTest.kt
index 84a7f28..e5fa158 100644
--- a/tools/android/webgpu/src/androidTest/java/androidx/webgpu/DeviceTest.kt
+++ b/tools/android/webgpu/src/androidTest/java/androidx/webgpu/DeviceTest.kt
@@ -56,7 +56,8 @@
)
val error = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(error.type, ErrorType.Validation)
+ assertEquals(error.status, PopErrorScopeStatus.Success)
}
@@ -144,7 +145,8 @@
)
)
val error = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(error.type, ErrorType.Validation)
+ assertEquals(error.status, PopErrorScopeStatus.Success)
}
@Test
@@ -168,7 +170,8 @@
)
val error = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(error.type, ErrorType.Validation)
+ assertEquals(error.status, PopErrorScopeStatus.Success)
}
/**
@@ -203,7 +206,8 @@
)
)
val error = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(error.type, ErrorType.Validation)
+ assertEquals(error.status, PopErrorScopeStatus.Success)
}
@Test
@@ -216,7 +220,8 @@
)
)
val error = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(error.type, ErrorType.Validation)
+ assertEquals(error.status, PopErrorScopeStatus.Success)
}
@Test
@@ -244,6 +249,7 @@
val unusedSampler = device.createSampler(invalidDescriptor)
val error = device.popErrorScope()
- assertEquals(error, ErrorType.Validation)
+ assertEquals(error.type, ErrorType.Validation)
+ assertEquals(error.status, PopErrorScopeStatus.Success)
}
}
diff --git a/tools/android/webgpu/src/androidTest/java/androidx/webgpu/QuerySetTest.kt b/tools/android/webgpu/src/androidTest/java/androidx/webgpu/QuerySetTest.kt
index d137fac..eb2f645 100644
--- a/tools/android/webgpu/src/androidTest/java/androidx/webgpu/QuerySetTest.kt
+++ b/tools/android/webgpu/src/androidTest/java/androidx/webgpu/QuerySetTest.kt
@@ -87,9 +87,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedQuerySet =
device.createQuerySet(QuerySetDescriptor(type = QueryType.Occlusion, count = -1))
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
}
@Test
@@ -133,9 +133,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.NoError, error)
+ assertEquals(ErrorType.NoError, errorScope.type)
querySet.destroy()
destinationBuffer.destroy()
@@ -156,9 +156,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
}
@Test
@@ -174,9 +174,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
}
@Test
@@ -193,9 +193,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
}
@Test
@@ -213,8 +213,8 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
}
}
\ No newline at end of file
diff --git a/tools/android/webgpu/src/androidTest/java/androidx/webgpu/RenderBundleEncoderTest.kt b/tools/android/webgpu/src/androidTest/java/androidx/webgpu/RenderBundleEncoderTest.kt
index 0a8162e..c6b43af 100644
--- a/tools/android/webgpu/src/androidTest/java/androidx/webgpu/RenderBundleEncoderTest.kt
+++ b/tools/android/webgpu/src/androidTest/java/androidx/webgpu/RenderBundleEncoderTest.kt
@@ -126,8 +126,8 @@
bundleEncoder.insertDebugMarker("Marker Inside Bundle")
device.pushErrorScope(ErrorFilter.Validation)
val unusedRenderBundle = bundleEncoder.finish()
- val error = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.NoError, error)
+ val errorScope = runBlocking { device.popErrorScope() }
+ assertEquals(ErrorType.NoError, errorScope.type)
}
@Test
@@ -137,9 +137,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedRenderBundle = bundleEncoder.finish() // Deferred error caught here.
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
}
@Test
@@ -150,9 +150,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedRenderBundle = bundleEncoder.finish() // Should succeed.
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.NoError, error)
+ assertEquals(ErrorType.NoError, errorScope.type)
}
@Test
@@ -162,9 +162,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedRenderBundle = bundleEncoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
}
@Test
@@ -175,9 +175,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedRenderBundle = bundleEncoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.NoError, error)
+ assertEquals(ErrorType.NoError, errorScope.type)
}
@Test
@@ -190,9 +190,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedRenderBundle = bundleEncoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
invalidBuffer.destroy()
}
@@ -206,9 +206,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedRenderBundle = bundleEncoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.NoError, error)
+ assertEquals(ErrorType.NoError, errorScope.type)
validBuffer.destroy()
}
@@ -220,9 +220,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedRenderBundle = bundleEncoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
}
@Test
@@ -235,9 +235,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val bundle = bundleEncoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.NoError, error)
+ assertEquals(ErrorType.NoError, errorScope.type)
indexBuffer.destroy()
bundle.close()
}
@@ -253,9 +253,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedRenderBundle = bundleEncoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
invalidBuffer.destroy()
}
@@ -268,9 +268,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val bundle = bundleEncoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.NoError, error)
+ assertEquals(ErrorType.NoError, errorScope.type)
indirectBuffer.destroy()
bundle.close()
}
@@ -284,9 +284,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedRenderBundle = bundleEncoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
indirectBuffer.destroy()
}
@@ -302,9 +302,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedRenderBundle = bundleEncoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.NoError, error)
+ assertEquals(ErrorType.NoError, errorScope.type)
indirectBuffer.destroy()
indexBuffer.destroy()
}
@@ -345,9 +345,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedRenderBundle = bundleEncoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
}
/**
@@ -407,9 +407,9 @@
device.pushErrorScope(ErrorFilter.Validation)
// Finish recording. Validation occurs here.
val unusedRenderBundle = bundleEncoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.NoError, error)
+ assertEquals(ErrorType.NoError, errorScope.type)
}
@Test
diff --git a/tools/android/webgpu/src/androidTest/java/androidx/webgpu/RenderPassEncoderTest.kt b/tools/android/webgpu/src/androidTest/java/androidx/webgpu/RenderPassEncoderTest.kt
index a5f24c1..ec77559 100644
--- a/tools/android/webgpu/src/androidTest/java/androidx/webgpu/RenderPassEncoderTest.kt
+++ b/tools/android/webgpu/src/androidTest/java/androidx/webgpu/RenderPassEncoderTest.kt
@@ -164,9 +164,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.NoError, error)
+ assertEquals(ErrorType.NoError, errorScope.type)
}
@Test
@@ -178,9 +178,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
}
@Test
@@ -192,9 +192,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
}
@Test
@@ -209,9 +209,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
invalidBuffer.destroy()
}
@@ -225,9 +225,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
}
@Test
@@ -242,9 +242,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val errorType = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.NoError, errorType)
+ assertEquals(ErrorType.NoError, errorScope.type)
indexBuffer.destroy()
}
@@ -261,9 +261,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
invalidBuffer.destroy()
}
@@ -278,9 +278,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
indirectBuffer.destroy()
}
@@ -297,9 +297,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.NoError, error)
+ assertEquals(ErrorType.NoError, errorScope.type)
indirectBuffer.destroy()
indexBuffer.destroy()
}
@@ -361,9 +361,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.NoError, error)
+ assertEquals(ErrorType.NoError, errorScope.type)
querySet.destroy()
depthPipeline.close()
@@ -381,9 +381,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.NoError, error)
+ assertEquals(ErrorType.NoError, errorScope.type)
}
@Test
@@ -394,9 +394,9 @@
device.pushErrorScope(ErrorFilter.Validation)
passEncoder.end() // Second call (invalid).
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.Validation, error)
+ assertEquals(ErrorType.Validation, errorScope.type)
}
@Test
@@ -417,9 +417,9 @@
device.pushErrorScope(ErrorFilter.Validation)
val unusedCommandBuffer = encoder.finish()
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals(ErrorType.NoError, error)
+ assertEquals(ErrorType.NoError, errorScope.type)
bundle.close()
}
diff --git a/tools/android/webgpu/src/androidTest/java/androidx/webgpu/ShaderModuleTest.kt b/tools/android/webgpu/src/androidTest/java/androidx/webgpu/ShaderModuleTest.kt
index 1f4d8b0..8991068 100644
--- a/tools/android/webgpu/src/androidTest/java/androidx/webgpu/ShaderModuleTest.kt
+++ b/tools/android/webgpu/src/androidTest/java/androidx/webgpu/ShaderModuleTest.kt
@@ -38,7 +38,8 @@
private suspend fun getCompilationInfo(code: String): CompilationInfo {
val shaderModule =
device.createShaderModule(ShaderModuleDescriptor(shaderSourceWGSL = ShaderSourceWGSL(code)))
- return shaderModule.getCompilationInfo()
+ val (_, info) = shaderModule.getCompilationInfo()
+ return info
}
@@ -62,11 +63,10 @@
fun invalidShader_producesACompilationError() {
device.pushErrorScope(ErrorFilter.Validation)
val info = runBlocking { getCompilationInfo(invalidShader) }
- val error = runBlocking { device.popErrorScope() }
+ val errorScope = runBlocking { device.popErrorScope() }
- assertEquals("The operation should result in a validation error",
- ErrorType.Validation, error
- )
+ // Assert that the operation resulted in a validation error
+ assertEquals(ErrorType.Validation, errorScope.type)
val errorCount = info.messages.count { it.type == CompilationMessageType.Error }
assertEquals(1, errorCount)
diff --git a/tools/android/webgpu/src/main/java/androidx/webgpu/helper/DawnException.kt b/tools/android/webgpu/src/main/java/androidx/webgpu/helper/DawnException.kt
index 11634c2..4a8b14e 100644
--- a/tools/android/webgpu/src/main/java/androidx/webgpu/helper/DawnException.kt
+++ b/tools/android/webgpu/src/main/java/androidx/webgpu/helper/DawnException.kt
@@ -1,10 +1,3 @@
package androidx.webgpu.helper
-import androidx.webgpu.Status
-
-public class DawnException(public val reason: String = "", public val status: Int? = null) :
- Exception(
- (if (status != null) "${Status.toString(status)}:" else "") + reason
- ) {
-
-}
\ No newline at end of file
+public class DawnException(message:String): Exception(message)
diff --git a/tools/android/webgpu/src/main/java/androidx/webgpu/helper/Textures.kt b/tools/android/webgpu/src/main/java/androidx/webgpu/helper/Textures.kt
index b4109b1..5a46bb9 100644
--- a/tools/android/webgpu/src/main/java/androidx/webgpu/helper/Textures.kt
+++ b/tools/android/webgpu/src/main/java/androidx/webgpu/helper/Textures.kt
@@ -56,7 +56,11 @@
it.finish()
}))
- readbackBuffer.mapAsync(MapMode.Read, 0, size.toLong())
+ val bufferMapReturn = readbackBuffer.mapAsync(MapMode.Read, 0, size.toLong())
+
+ if (bufferMapReturn.status != MapAsyncStatus.Success) {
+ throw DawnException("Failed to map buffer: ${bufferMapReturn.message}")
+ }
return Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888).apply {
copyPixelsFromBuffer(readbackBuffer.getConstMappedRange(size = readbackBuffer.size))
diff --git a/tools/android/webgpu/src/main/java/androidx/webgpu/helper/WebGpu.kt b/tools/android/webgpu/src/main/java/androidx/webgpu/helper/WebGpu.kt
index 18bc274..20e3b5f 100644
--- a/tools/android/webgpu/src/main/java/androidx/webgpu/helper/WebGpu.kt
+++ b/tools/android/webgpu/src/main/java/androidx/webgpu/helper/WebGpu.kt
@@ -105,7 +105,12 @@
instance: GPUInstance,
options: RequestAdapterOptions = RequestAdapterOptions(backendType = BackendType.Vulkan),
): GPUAdapter {
- return instance.requestAdapter(options)
+ val result = instance.requestAdapter(options)
+ val adapter = result.adapter
+ check(result.status == RequestAdapterStatus.Success && adapter != null) {
+ result. message.ifEmpty { "Error requesting the adapter: $result.status" }
+ }
+ return adapter
}
private suspend inline fun requestDevice(
@@ -119,7 +124,12 @@
if (deviceDescriptor.uncapturedErrorCallback == null) {
deviceDescriptor.uncapturedErrorCallback = defaultUncapturedErrorCallback
}
- return adapter.requestDevice(deviceDescriptor)
+ val result = adapter.requestDevice(deviceDescriptor)
+ val device = result.device
+ check(result.status == RequestDeviceStatus.Success && device != null) {
+ result.message.ifEmpty { "Error requesting the device: $result.status" }
+ }
+ return device
}
private val defaultUncapturedErrorCallback get(): UncapturedErrorCallback {