From b045fac84b5a6914074fd28f229807eadb4e0f81 Mon Sep 17 00:00:00 2001 From: twiceYuan Date: Tue, 25 Jul 2023 14:46:46 +0800 Subject: [PATCH] Update to Gradle 7/AGP 7/Kotlin 1.8/Android 13, replace DataBinding to ViewBinding --- bert/build.gradle | 20 ++-- bert/download.gradle | 2 +- bert/src/main/AndroidManifest.xml | 6 +- .../bertqa/ui/DatasetListActivity.java | 3 +- .../bertqa/ui/QaActivity.java | 3 +- .../bertqa/ui/QuestionAdapter.java | 3 +- build.gradle | 10 +- gpt2/build.gradle | 16 ++- gpt2/download.gradle | 2 +- gpt2/src/main/AndroidManifest.xml | 3 +- .../android_transformers/gpt2/MainActivity.kt | 46 ++++++-- .../gpt2/ml/GPT2Client.kt | 34 ++---- .../gpt2/tokenization/GPT2Tokenizer.kt | 2 +- gpt2/src/main/res/layout/activity_main.xml | 102 ++++++++---------- gradle/wrapper/gradle-wrapper.properties | 2 +- 15 files changed, 125 insertions(+), 129 deletions(-) diff --git a/bert/build.gradle b/bert/build.gradle index 6ee4bda..1059395 100644 --- a/bert/build.gradle +++ b/bert/build.gradle @@ -1,12 +1,13 @@ apply plugin: 'com.android.application' android { - compileSdkVersion 29 - buildToolsVersion "29.0.2" + compileSdkVersion 33 + namespace "co.huggingface.android_transformers.bertqa" + defaultConfig { applicationId "co.huggingface.android_transformers.bert" minSdkVersion 26 - targetSdkVersion 29 + targetSdkVersion 33 versionCode 1 versionName "1.0" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" @@ -21,8 +22,8 @@ android { noCompress "tflite" } compileOptions { - sourceCompatibility JavaVersion.VERSION_1_8 - targetCompatibility JavaVersion.VERSION_1_8 + sourceCompatibility JavaVersion.VERSION_11 + targetCompatibility JavaVersion.VERSION_11 } // If you find lint problem like: // * What went wrong: @@ -53,7 +54,7 @@ dependencies { implementation 'com.google.guava:guava:28.1-android' // implementation 'org.tensorflow:tensorflow-lite:2.0.0' // implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:0.0.0-nightly' - compile(name: 'tensorflow-lite-with-select-tf-ops-0.0.0-nightly', ext: 'aar') + implementation files('libs/tensorflow-lite-with-select-tf-ops-0.0.0-nightly.aar') testImplementation 'junit:junit:4.12' testImplementation 'androidx.test:core:1.2.0' @@ -62,10 +63,3 @@ dependencies { androidTestImplementation 'androidx.test:runner:1.2.0' androidTestImplementation 'androidx.test.espresso:espresso-core:3.2.0' } - -repositories { - flatDir { - dirs 'libs' - } - mavenCentral() -} diff --git a/bert/download.gradle b/bert/download.gradle index 902c7d8..904458f 100644 --- a/bert/download.gradle +++ b/bert/download.gradle @@ -1,6 +1,6 @@ apply plugin: 'de.undercouch.download' -task downloadLiteModel { +task downloadLiteModel(type: Download) { def downloadFiles = [ 'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-384.tflite': 'model.tflite', // 'https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-384-fp16.tflite': 'model.tflite', // FP16 quantization version diff --git a/bert/src/main/AndroidManifest.xml b/bert/src/main/AndroidManifest.xml index 182abff..2c2960b 100644 --- a/bert/src/main/AndroidManifest.xml +++ b/bert/src/main/AndroidManifest.xml @@ -1,7 +1,7 @@ + package="co.huggingface.android_transformers.bertqa"> diff --git a/bert/src/main/java/co/huggingface/android_transformers/bertqa/ui/DatasetListActivity.java b/bert/src/main/java/co/huggingface/android_transformers/bertqa/ui/DatasetListActivity.java index 781c49d..a6a6daf 100644 --- a/bert/src/main/java/co/huggingface/android_transformers/bertqa/ui/DatasetListActivity.java +++ b/bert/src/main/java/co/huggingface/android_transformers/bertqa/ui/DatasetListActivity.java @@ -19,7 +19,8 @@ import android.widget.ArrayAdapter; import android.widget.ListView; -import co.huggingface.android_transformers.R; + +import co.huggingface.android_transformers.bertqa.R; import co.huggingface.android_transformers.bertqa.ml.LoadDatasetClient; /** diff --git a/bert/src/main/java/co/huggingface/android_transformers/bertqa/ui/QaActivity.java b/bert/src/main/java/co/huggingface/android_transformers/bertqa/ui/QaActivity.java index a40da06..a6dbf65 100644 --- a/bert/src/main/java/co/huggingface/android_transformers/bertqa/ui/QaActivity.java +++ b/bert/src/main/java/co/huggingface/android_transformers/bertqa/ui/QaActivity.java @@ -39,7 +39,8 @@ import com.google.android.material.textfield.TextInputEditText; import java.util.List; import java.util.Locale; -import co.huggingface.android_transformers.R; + +import co.huggingface.android_transformers.bertqa.R; import co.huggingface.android_transformers.bertqa.ml.LoadDatasetClient; import co.huggingface.android_transformers.bertqa.ml.QaAnswer; import co.huggingface.android_transformers.bertqa.ml.QaClient; diff --git a/bert/src/main/java/co/huggingface/android_transformers/bertqa/ui/QuestionAdapter.java b/bert/src/main/java/co/huggingface/android_transformers/bertqa/ui/QuestionAdapter.java index b0e4802..eb59e4f 100644 --- a/bert/src/main/java/co/huggingface/android_transformers/bertqa/ui/QuestionAdapter.java +++ b/bert/src/main/java/co/huggingface/android_transformers/bertqa/ui/QuestionAdapter.java @@ -21,7 +21,8 @@ import android.view.View; import android.view.ViewGroup; import com.google.android.material.chip.Chip; -import co.huggingface.android_transformers.R; + +import co.huggingface.android_transformers.bertqa.R; /** Adapter class to show question suggestion chips. */ public class QuestionAdapter extends RecyclerView.Adapter { diff --git a/build.gradle b/build.gradle index cf487e5..c70bcde 100644 --- a/build.gradle +++ b/build.gradle @@ -1,14 +1,14 @@ // Top-level build file where you can add configuration options common to all sub-projects/modules. buildscript { - ext.kotlin_version = '1.3.61' + ext.kotlin_version = '1.8.20' repositories { google() - jcenter() + mavenCentral() } dependencies { - classpath 'com.android.tools.build:gradle:3.5.3' - classpath 'de.undercouch:gradle-download-task:4.0.0' + classpath 'com.android.tools.build:gradle:7.2.2' + classpath 'de.undercouch:gradle-download-task:5.0.4' classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version" // NOTE: Do not place your application dependencies here; they belong // in the individual module build.gradle files @@ -18,7 +18,7 @@ buildscript { allprojects { repositories { google() - jcenter() + mavenCentral() } } diff --git a/gpt2/build.gradle b/gpt2/build.gradle index 42c7a4b..b5a1032 100644 --- a/gpt2/build.gradle +++ b/gpt2/build.gradle @@ -1,17 +1,15 @@ apply plugin: 'com.android.application' apply plugin: 'kotlin-android' -apply plugin: 'kotlin-android-extensions' -apply plugin: 'kotlin-kapt' android { - compileSdkVersion 29 - buildToolsVersion "29.0.2" - + compileSdkVersion 33 + namespace "co.huggingface.android_transformers.gpt2" defaultConfig { applicationId "co.huggingface.android_transformers.gpt2" + namespace "co.huggingface.android_transformers.gpt2" minSdkVersion 26 - targetSdkVersion 29 + targetSdkVersion 33 versionCode 1 versionName "1.0" @@ -22,12 +20,12 @@ android { noCompress "tflite" } - dataBinding { - enabled = true + buildFeatures { + viewBinding true } kotlinOptions { - jvmTarget = JavaVersion.VERSION_1_8 + jvmTarget = JavaVersion.VERSION_11 } buildTypes { diff --git a/gpt2/download.gradle b/gpt2/download.gradle index afb8876..d7f3d1c 100644 --- a/gpt2/download.gradle +++ b/gpt2/download.gradle @@ -1,6 +1,6 @@ apply plugin: 'de.undercouch.download' -task downloadLiteModel { +task downloadLiteModel(type: Download) { def downloadFiles = [ "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json": "gpt2-vocab.json", "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt": "gpt2-merges.txt", diff --git a/gpt2/src/main/AndroidManifest.xml b/gpt2/src/main/AndroidManifest.xml index 0ddae47..49fe99a 100644 --- a/gpt2/src/main/AndroidManifest.xml +++ b/gpt2/src/main/AndroidManifest.xml @@ -11,7 +11,8 @@ android:supportsRtl="true" android:theme="@style/AppTheme" tools:ignore="GoogleAppIndexingWarning"> - + diff --git a/gpt2/src/main/java/co/huggingface/android_transformers/gpt2/MainActivity.kt b/gpt2/src/main/java/co/huggingface/android_transformers/gpt2/MainActivity.kt index 542af49..d5aa877 100644 --- a/gpt2/src/main/java/co/huggingface/android_transformers/gpt2/MainActivity.kt +++ b/gpt2/src/main/java/co/huggingface/android_transformers/gpt2/MainActivity.kt @@ -1,9 +1,12 @@ package co.huggingface.android_transformers.gpt2 -import androidx.appcompat.app.AppCompatActivity import android.os.Bundle +import android.text.Spannable +import android.text.SpannableStringBuilder +import android.widget.TextView import androidx.activity.viewModels -import androidx.databinding.DataBindingUtil +import androidx.appcompat.app.AppCompatActivity +import androidx.core.content.res.ResourcesCompat import co.huggingface.android_transformers.gpt2.databinding.ActivityMainBinding class MainActivity : AppCompatActivity() { @@ -11,14 +14,39 @@ class MainActivity : AppCompatActivity() { override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState) + val binding: ActivityMainBinding = ActivityMainBinding.inflate(layoutInflater) + setContentView(binding.root) - val binding: ActivityMainBinding - = DataBindingUtil.setContentView(this, R.layout.activity_main) - - // Bind layout with ViewModel - binding.vm = gpt2 + binding.autocompleteButton.setOnClickListener { + gpt2.launchAutocomplete() + } + binding.shuffleButton.setOnClickListener { + gpt2.refreshPrompt() + } + gpt2.completion.observe(this) { completion -> + gpt2.prompt.observe(this) { prompt -> + binding.prompt.formatCompletion(prompt, completion) + } + } + } - // LiveData needs the lifecycle owner - binding.lifecycleOwner = this + private fun TextView.formatCompletion(prompt: String, completion: String) { + text = when { + completion.isEmpty() -> prompt + else -> { + val str = SpannableStringBuilder(prompt + completion) + val bgCompletionColor = + ResourcesCompat.getColor(resources, R.color.colorPrimary, context.theme) + str.apply { + setSpan( + android.text.style.BackgroundColorSpan(bgCompletionColor), + prompt.length, + str.length, + Spannable.SPAN_EXCLUSIVE_EXCLUSIVE + ) + } + } + } } + } diff --git a/gpt2/src/main/java/co/huggingface/android_transformers/gpt2/ml/GPT2Client.kt b/gpt2/src/main/java/co/huggingface/android_transformers/gpt2/ml/GPT2Client.kt index 62abdaa..956072a 100644 --- a/gpt2/src/main/java/co/huggingface/android_transformers/gpt2/ml/GPT2Client.kt +++ b/gpt2/src/main/java/co/huggingface/android_transformers/gpt2/ml/GPT2Client.kt @@ -1,16 +1,18 @@ package co.huggingface.android_transformers.gpt2.ml import android.app.Application -import android.text.Spannable -import android.text.SpannableStringBuilder import android.util.JsonReader -import android.widget.TextView -import androidx.core.content.res.ResourcesCompat -import androidx.databinding.BindingAdapter -import androidx.lifecycle.* -import co.huggingface.android_transformers.gpt2.R +import androidx.lifecycle.AndroidViewModel +import androidx.lifecycle.LiveData +import androidx.lifecycle.MutableLiveData +import androidx.lifecycle.viewModelScope import co.huggingface.android_transformers.gpt2.tokenization.GPT2Tokenizer -import kotlinx.coroutines.* +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.cancelAndJoin +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import kotlinx.coroutines.yield import org.tensorflow.lite.Interpreter import java.io.BufferedReader import java.io.FileInputStream @@ -107,7 +109,7 @@ class GPT2Client(application: Application) : AndroidViewModel(application) { // Softmax computation on filtered logits val filteredLogits = filteredLogitsWithIndexes.map { it.second } - val maxLogitValue = filteredLogits.max()!! + val maxLogitValue = filteredLogits.maxOrNull()!! val logitsExp = filteredLogits.map { exp(it - maxLogitValue) } val sumExp = logitsExp.sum() val probs = logitsExp.map { it.div(sumExp) } @@ -200,17 +202,3 @@ private fun FloatArray.argmax(): Int { return bestIndex } - -@BindingAdapter("prompt", "completion") -fun TextView.formatCompletion(prompt: String, completion: String): Unit { - text = when { - completion.isEmpty() -> prompt - else -> { - val str = SpannableStringBuilder(prompt + completion) - val bgCompletionColor = ResourcesCompat.getColor(resources, R.color.colorPrimary, context.theme) - str.setSpan(android.text.style.BackgroundColorSpan(bgCompletionColor), prompt.length, str.length, Spannable.SPAN_EXCLUSIVE_EXCLUSIVE) - - str - } - } -} diff --git a/gpt2/src/main/java/co/huggingface/android_transformers/gpt2/tokenization/GPT2Tokenizer.kt b/gpt2/src/main/java/co/huggingface/android_transformers/gpt2/tokenization/GPT2Tokenizer.kt index 0899b89..6f08009 100644 --- a/gpt2/src/main/java/co/huggingface/android_transformers/gpt2/tokenization/GPT2Tokenizer.kt +++ b/gpt2/src/main/java/co/huggingface/android_transformers/gpt2/tokenization/GPT2Tokenizer.kt @@ -36,7 +36,7 @@ class GPT2Tokenizer( while (true) { if (!pairs.any { bpeRanks.containsKey(it) }) break - val (first, second) = pairs.minBy { bpeRanks.getOrDefault(it, Int.MAX_VALUE) } ?: break + val (first, second) = pairs.minByOrNull { bpeRanks.getOrDefault(it, Int.MAX_VALUE) } ?: break var i = 0 val newWord = mutableListOf() diff --git a/gpt2/src/main/res/layout/activity_main.xml b/gpt2/src/main/res/layout/activity_main.xml index b994719..5bd56fa 100644 --- a/gpt2/src/main/res/layout/activity_main.xml +++ b/gpt2/src/main/res/layout/activity_main.xml @@ -1,61 +1,45 @@ - - - - - - - - - - - - - - - - - - + xmlns:tools="http://schemas.android.com/tools" + android:layout_width="match_parent" + android:layout_height="match_parent" + tools:context=".MainActivity"> + + + + + + + + diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index a95fa1a..43cadad 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -3,4 +3,4 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-5.4.1-all.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-7.4-all.zip