Integration test in the library.

Change-Id: I48a196b12752eb078214f1f2a3c1d3dbda439ae3
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/193540
Commit-Queue: Jim Blackler <jimblackler@google.com>
Reviewed-by: Jim Blackler <jimblackler@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/tools/android/webgpu/build.gradle b/tools/android/webgpu/build.gradle
index 28e65a4..d74675e 100644
--- a/tools/android/webgpu/build.gradle
+++ b/tools/android/webgpu/build.gradle
@@ -35,7 +35,8 @@
     namespace 'net.android.webgpu'
     defaultConfig {
         minSdkVersion 26
-        compileSdk 31
+        compileSdk 34
+        testInstrumentationRunner 'androidx.test.runner.AndroidJUnitRunner'
         externalNativeBuild {
             cmake {
                 arguments '-DANDROID_STL=c++_shared'
@@ -73,8 +74,11 @@
 }
 
 dependencies {
-    implementation 'androidx.core:core-ktx:1.12.0'
-    testImplementation 'junit:junit:4.12'
+    implementation 'androidx.core:core-ktx:1.13.1'
+    testImplementation 'junit:junit:4.13.2'
+    androidTestImplementation 'androidx.test.ext:junit-ktx:1.1.5'
+    androidTestImplementation 'androidx.test:runner:1.5.2'
+    androidTestImplementation 'org.jetbrains.kotlinx:kotlinx-coroutines-test:1.3.0'
 }
 
 project.afterEvaluate {
diff --git a/tools/android/webgpu/src/androidTest/assets/compare/green.png b/tools/android/webgpu/src/androidTest/assets/compare/green.png
new file mode 100644
index 0000000..a271fab
--- /dev/null
+++ b/tools/android/webgpu/src/androidTest/assets/compare/green.png
Binary files differ
diff --git a/tools/android/webgpu/src/androidTest/assets/compare/red.png b/tools/android/webgpu/src/androidTest/assets/compare/red.png
new file mode 100644
index 0000000..387f600
--- /dev/null
+++ b/tools/android/webgpu/src/androidTest/assets/compare/red.png
Binary files differ
diff --git a/tools/android/webgpu/src/androidTest/assets/triangle/shader.wgsl b/tools/android/webgpu/src/androidTest/assets/triangle/shader.wgsl
new file mode 100644
index 0000000..d2b30be
--- /dev/null
+++ b/tools/android/webgpu/src/androidTest/assets/triangle/shader.wgsl
@@ -0,0 +1,8 @@
+@vertex fn vertexMain(@builtin(vertex_index) i : u32) ->
+@builtin(position) vec4f {
+    const pos = array(vec2f(-0.5, -0.5), vec2f(0.5, -0.5), vec2f(0.0, 0.5));
+    return vec4f(pos[i], 0, 1);
+}
+@fragment fn fragmentMain() -> @location(0) vec4f {
+    return vec4f(0.0, 0.4, 1.0, 1.0);
+}
\ No newline at end of file
diff --git a/tools/android/webgpu/src/androidTest/java/net/android/dawn/ImageCompare.kt b/tools/android/webgpu/src/androidTest/java/net/android/dawn/ImageCompare.kt
new file mode 100644
index 0000000..8c0eba2
--- /dev/null
+++ b/tools/android/webgpu/src/androidTest/java/net/android/dawn/ImageCompare.kt
@@ -0,0 +1,35 @@
+package net.android.dawn
+
+import android.graphics.Bitmap
+import kotlin.math.pow
+import kotlin.math.sqrt
+
+fun Int.floatFrom(pos: Int) = (this shr 8 * pos and 255).toFloat() / 255
+
+val Int.blue get() = floatFrom(0)
+val Int.green get() = floatFrom(1)
+val Int.red get() = floatFrom(2)
+val Int.alpha get() = floatFrom(3)
+
+fun imageSimilarity(a: Bitmap, b: Bitmap): Float {
+    if (a.width != b.width || a.height != b.height) {
+        return 0f
+    }
+
+    var sumSimilarity = 0f
+
+    for (y in 0 until a.height) {
+        for (x in 0 until a.width) {
+            val ac = a.getPixel(x, y)
+            val bc = b.getPixel(x, y)
+            sumSimilarity += (1 - sqrt(
+                (ac.red - bc.red).pow(2.0f) +
+                        (ac.green - bc.green).pow(2.0f) +
+                        (ac.blue - bc.blue).pow(2.0f) +
+                        (ac.alpha - bc.alpha).pow(2.0f)
+            ) / 2)
+        }
+    }
+
+    return sumSimilarity / (a.width * a.height)
+}
diff --git a/tools/android/webgpu/src/androidTest/java/net/android/dawn/ImageTest.kt b/tools/android/webgpu/src/androidTest/java/net/android/dawn/ImageTest.kt
new file mode 100644
index 0000000..f7f8847
--- /dev/null
+++ b/tools/android/webgpu/src/androidTest/java/net/android/dawn/ImageTest.kt
@@ -0,0 +1,159 @@
+package net.android.dawn
+
+import android.dawn.*
+import android.dawn.helper.DawnException
+import android.dawn.helper.Util
+import android.dawn.helper.createImageFromGpuTexture
+import android.dawn.helper.streamToString
+import android.graphics.Bitmap
+import android.graphics.BitmapFactory
+import android.os.Environment
+import androidx.test.ext.junit.runners.AndroidJUnit4
+import androidx.test.platform.app.InstrumentationRegistry
+import junit.framework.TestCase.assertEquals
+import kotlinx.coroutines.delay
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.test.runTest
+import org.junit.Test
+import org.junit.runner.RunWith
+import java.io.BufferedOutputStream
+import java.io.File
+import java.io.FileOutputStream
+
+@RunWith(AndroidJUnit4::class)
+class ImageTest {
+    @Test
+    fun imageCompareGreen() {
+        triangleTest(Color(0.2, 0.9, 0.1, 1.0), "green.png")
+    }
+
+    @Test
+    fun imageCompareRed() {
+        triangleTest(Color(0.9, 0.1, 0.2, 1.0), "red.png")
+    }
+
+    private fun triangleTest(color: Color, imageName: String) {
+        runTest {
+            Util  // Hack to force library initialization.
+            val instrumentation = InstrumentationRegistry.getInstrumentation()
+            val appContext = instrumentation.targetContext
+
+            val instance = createInstance()
+
+            val eventProcessor = launch {
+                while (true) {
+                    delay(200)
+                    instance.processEvents()
+                }
+            }
+
+            val (_, adapter, _) = instance.requestAdapter()
+
+            if (adapter == null) {
+                throw DawnException("No adapter available")
+            }
+
+            val (_, device, _) = adapter.requestDevice()
+
+            if (device == null) {
+                throw DawnException("No device available")
+            }
+
+            device.setUncapturedErrorCallback { type, message ->
+                throw DawnException(message)
+            }
+
+            val shaderModule = device.createShaderModule(
+                ShaderModuleDescriptor(
+                    shaderModuleWGSLDescriptor = ShaderModuleWGSLDescriptor(
+                        code = streamToString(appContext.assets.open("triangle/shader.wgsl"))
+                    )
+                )
+            )
+
+            val testTexture = device.createTexture(
+                TextureDescriptor(
+                    size = Extent3D(256, 256),
+                    format = TextureFormat.RGBA8Unorm,
+                    usage = TextureUsage.CopySrc or TextureUsage.RenderAttachment
+                )
+            )
+
+            with(device.queue) {
+                submit(device.createCommandEncoder().use {
+                    with(
+                        it.beginRenderPass(
+                            RenderPassDescriptor(
+                                colorAttachments = arrayOf(
+                                    RenderPassColorAttachment(
+                                        loadOp = LoadOp.Clear,
+                                        storeOp = StoreOp.Store,
+                                        clearValue = color,
+                                        view = testTexture.createView()
+                                    )
+                                )
+                            )
+                        )
+                    ) {
+                        setPipeline(
+                            device.createRenderPipeline(
+                                RenderPipelineDescriptor(
+                                    vertex = VertexState(module = shaderModule),
+                                    primitive = PrimitiveState(
+                                        topology = PrimitiveTopology.TriangleList
+                                    ),
+                                    fragment = FragmentState(
+                                        module = shaderModule,
+                                        targets = arrayOf(
+                                            ColorTargetState(
+                                                format = TextureFormat.RGBA8Unorm
+                                            )
+                                        )
+                                    )
+                                )
+                            )
+                        )
+                        draw(3)
+                        end()
+                    }
+
+                    arrayOf(it.finish())
+                })
+            }
+
+            val bitmap = createImageFromGpuTexture(device, testTexture)
+
+            if (false) {
+                writeReferenceImage(bitmap)
+            }
+
+            val testAssets = instrumentation.context.assets
+            val matched = testAssets.list("compare")!!.filter {
+                imageSimilarity(
+                    bitmap,
+                    BitmapFactory.decodeStream(testAssets.open("compare/$it"))
+                ) > 0.99
+            }
+
+            assertEquals(listOf(imageName), matched)
+
+            device.close()
+            device.destroy()
+            adapter.close()
+
+            eventProcessor.cancel()
+            eventProcessor.join()
+
+            instance.close()
+        }
+    }
+
+    private fun writeReferenceImage(bitmap: Bitmap) {
+        val path = Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOWNLOADS)
+        val file = File("${path}${File.separator}${"reference.png"}")
+        BufferedOutputStream(FileOutputStream(file)).use {
+            bitmap.compress(Bitmap.CompressFormat.PNG, 100, it)
+            it.close()
+        }
+    }
+}
\ No newline at end of file
diff --git a/tools/android/webgpu/src/main/java/android/dawn/helper/DawnException.kt b/tools/android/webgpu/src/main/java/android/dawn/helper/DawnException.kt
new file mode 100644
index 0000000..4126b64
--- /dev/null
+++ b/tools/android/webgpu/src/main/java/android/dawn/helper/DawnException.kt
@@ -0,0 +1,3 @@
+package android.dawn.helper
+
+class DawnException(message:String): Exception(message)
diff --git a/tools/android/webgpu/src/main/java/android/dawn/helper/Rounding.kt b/tools/android/webgpu/src/main/java/android/dawn/helper/Rounding.kt
new file mode 100644
index 0000000..29430b2
--- /dev/null
+++ b/tools/android/webgpu/src/main/java/android/dawn/helper/Rounding.kt
@@ -0,0 +1,6 @@
+package android.dawn.helper
+
+fun Long.roundDownToNearestMultipleOf(boundary: Int) = this / boundary * boundary
+fun Int.roundDownToNearestMultipleOf(boundary: Int) = this / boundary * boundary
+fun Long.roundUpToNearestMultipleOf(boundary: Int) = (this + boundary - 1) / boundary * boundary
+fun Int.roundUpToNearestMultipleOf(boundary: Int) = (this + boundary - 1) / boundary * boundary
diff --git a/tools/android/webgpu/src/main/java/android/dawn/helper/Streams.kt b/tools/android/webgpu/src/main/java/android/dawn/helper/Streams.kt
new file mode 100644
index 0000000..4a52b54
--- /dev/null
+++ b/tools/android/webgpu/src/main/java/android/dawn/helper/Streams.kt
@@ -0,0 +1,17 @@
+package android.dawn.helper
+
+import java.io.InputStream
+import java.io.InputStreamReader
+import java.nio.charset.StandardCharsets
+import java.util.Scanner
+
+fun streamToString(inputStream: InputStream): String {
+    val scanner = Scanner(
+        InputStreamReader(
+            inputStream,
+            StandardCharsets.UTF_8
+        )
+    ).useDelimiter("\\A")
+    return if (scanner.hasNext()) scanner.next() else ""
+}
+
diff --git a/tools/android/webgpu/src/main/java/android/dawn/helper/Textures.kt b/tools/android/webgpu/src/main/java/android/dawn/helper/Textures.kt
new file mode 100644
index 0000000..be518b5
--- /dev/null
+++ b/tools/android/webgpu/src/main/java/android/dawn/helper/Textures.kt
@@ -0,0 +1,66 @@
+package android.dawn.helper
+
+import android.dawn.*
+import android.graphics.Bitmap
+import java.nio.ByteBuffer
+
+fun createGpuTextureFromBitmap(device: Device, bitmap: Bitmap): Texture {
+    val size = Extent3D(width = bitmap.width, height = bitmap.height)
+    return device.createTexture(
+        TextureDescriptor(
+            size = size,
+            format = TextureFormat.RGBA8Unorm,
+            usage = TextureUsage.TextureBinding or TextureUsage.CopyDst or
+                    TextureUsage.RenderAttachment
+        )
+    ).also { texture ->
+        ByteBuffer.allocateDirect(bitmap.height * bitmap.width * Int.SIZE_BYTES).let { pixels ->
+            bitmap.copyPixelsToBuffer(pixels)
+            device.queue.writeTexture(
+                dataLayout = TextureDataLayout(
+
+                    bytesPerRow = bitmap.width * Int.SIZE_BYTES,
+                    rowsPerImage = bitmap.height
+                ),
+                data = pixels,
+                destination = ImageCopyTexture(texture = texture),
+                writeSize = size
+            )
+        }
+    }
+}
+
+suspend fun createImageFromGpuTexture(device: Device, texture: Texture): Bitmap {
+    if (texture.width % 64 > 0) {
+        throw DawnException("Texture must be a multiple of 64. Was ${texture.width}")
+    }
+
+    val size = texture.width * texture.height * Int.SIZE_BYTES
+    val readbackBuffer = device.createBuffer(
+        BufferDescriptor(
+            size = size.toLong(),
+            usage = BufferUsage.CopyDst or BufferUsage.MapRead
+        )
+    )
+    device.queue.submit(arrayOf(device.createCommandEncoder().run {
+        copyTextureToBuffer(
+            source = ImageCopyTexture(texture = texture),
+            destination = ImageCopyBuffer(
+                layout = TextureDataLayout(
+                    offset = 0,
+                    bytesPerRow = texture.width * Int.SIZE_BYTES,
+                    rowsPerImage = texture.height
+                ), buffer = readbackBuffer
+            ),
+            copySize = Extent3D(width = texture.width, height = texture.height)
+        )
+        finish()
+    }))
+
+    readbackBuffer.mapAsync(MapMode.Read, 0, size.toLong())
+
+    return Bitmap.createBitmap(texture.width, texture.height, Bitmap.Config.ARGB_8888).apply {
+        copyPixelsFromBuffer(readbackBuffer.getConstMappedRange(size = readbackBuffer.size))
+        readbackBuffer.unmap()
+    }
+}