小男孩‘自慰网亚洲一区二区,亚洲一级在线播放毛片,亚洲中文字幕av每天更新,黄aⅴ永久免费无码,91成人午夜在线精品,色网站免费在线观看,亚洲欧洲wwwww在线观看

分享

如何將訓練好的pytorch模型部署到安卓設備上

 520jefferson 2022-04-10

編輯:學姐

這篇文章演示如何將訓練好的pytorch模型部署到安卓設備上。我也是剛開始學安卓,代碼寫的簡單。

環(huán)境:pytorch版本:1.10.0

# 模型轉(zhuǎn)化

pytorch_android支持的模型是.pt模型,我們訓練出來的模型是.pth。所以需要轉(zhuǎn)化才可以用。

先看官網(wǎng)上給的轉(zhuǎn)化方式:

import torchimport torchvisionfrom torch.utils.mobile_optimizer import optimize_for_mobile
model = torchvision.models.mobilenet_v3_small(pretrained=True)model.eval()example = torch.rand(1, 3, 224, 224)traced_script_module = torch.jit.trace(model, example)optimized_traced_model = optimize_for_mobile(traced_script_module)optimized_traced_model._save_for_lite_interpreter('app/src/main/assets/model.ptl')

這個模型在安卓對應的包:

repositories {    jcenter()}
dependencies { implementation 'org.pytorch:pytorch_android_lite:1.9.0' implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'}

注:pytorch_android_lite版本和轉(zhuǎn)化模型用的版本要一致,不一致就會報各種錯誤。

目前用這種方法有點問題,我采用的另一種方法。

轉(zhuǎn)化代碼如下:

import torchimport torch.utils.data.distributed
# pytorch環(huán)境中model_pth = 'model_31_0.96.pth' #模型的參數(shù)文件mobile_pt ='model.pt' # 將模型保存為Android可以調(diào)用的文件
model = torch.load(model_pth)model.eval() # 模型設為評估模式device = torch.device('cpu')model.to(device)# 1張3通道224*224的圖片input_tensor = torch.rand(1, 3, 224, 224) # 設定輸入數(shù)據(jù)格式
mobile = torch.jit.trace(model, input_tensor) # 模型轉(zhuǎn)化mobile.save(mobile_pt) # 保存文件

定義模型文件和轉(zhuǎn)化后的文件路徑。

load模型。(這里要注意,如果保存模型)

torch.save(model,'models.pth')

加載模型則是

model=torch.load('models.pth')

如果保存模型是

torch.save(model.state_dict(),'models.pth')

加載模型則是

model.load_state_dict(torch.load('models.pth'))

定義輸入數(shù)據(jù)格式。

模型轉(zhuǎn)化,然后再保存模型。

# 安卓部署
新建項目


新建安卓項目,選擇Empy Activity,然后選擇Next

Image

然后,填寫項目信息,選擇安卓版本,我用的4.4,點擊完成

Image

導入包

導入pytorch_android的包

//pytorchimplementation 'org.pytorch:pytorch_android:1.10.0'implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'

Image

如果有參數(shù)報錯請參照我的完整的配置,代碼如下:

plugins { id 'com.android.application'}
android { compileSdk 32
defaultConfig { applicationId 'com.example.myapplication' minSdk 21 targetSdk 32 versionCode 1 versionName '1.0'
testInstrumentationRunner 'androidx.test.runner.AndroidJUnitRunner' }
buildTypes { release { minifyEnabled false proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' } } compileOptions { sourceCompatibility JavaVersion.VERSION_1_8 targetCompatibility JavaVersion.VERSION_1_8 }}
dependencies {
implementation 'androidx.appcompat:appcompat:1.3.0' implementation 'com.google.android.material:material:1.4.0' implementation 'androidx.constraintlayout:constraintlayout:2.0.4' testImplementation 'junit:junit:4.13.2' androidTestImplementation 'androidx.test.ext:junit:1.1.3' androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0' //pytorch implementation 'org.pytorch:pytorch_android:1.10.0' implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'
}
頁面文件

頁面的配置如下:

<?xml version='1.0' encoding='utf-8'?><FrameLayout xmlns:android='http://schemas./apk/res/android'    xmlns:tools='http://schemas./tools'    android:layout_width='match_parent'    android:layout_height='match_parent'    tools:context='.MainActivity'>
<ImageView android:id='@+id/image' android:layout_width='match_parent' android:layout_height='match_parent' android:scaleType='fitCenter' />
<TextView android:id='@+id/text' android:layout_width='match_parent' android:layout_height='wrap_content' android:layout_gravity='top' android:textSize='24sp' android:background='#80000000' android:textColor='@android:color/holo_red_light' />
</FrameLayout>

這個頁面只有兩個空間,一個展示圖片,一個顯示文字。

Image

模型推理

新增assets文件夾,然后將轉(zhuǎn)化的模型和待測試的圖片放進去。

Image

新增ImageNetClasses類,這個類存放類別名字。

Image

代碼如下:

package com.example.myapplication;
public class ImageNetClasses { public static String[] IMAGENET_CLASSES = new String[]{ 'Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed', 'Common wheat', 'Fat Hen', 'Loose Silky-bent', 'Maize', 'Scentless Mayweed', 'Shepherds Purse', 'Small-flowered Cranesbill', 'Sugar beet',
};}

在MainActivity類中,增加模型推理的邏輯。

完成代碼如下:

package com.example.myapplication;
import android.content.Context;import android.graphics.Bitmap;import android.graphics.BitmapFactory;import android.os.Bundle;import android.util.Log;import android.widget.ImageView;import android.widget.TextView;
import org.pytorch.IValue;
import org.pytorch.Module;import org.pytorch.Tensor;import org.pytorch.torchvision.TensorImageUtils;import org.pytorch.MemoryFormat;import java.io.File;import java.io.FileOutputStream;import java.io.IOException;import java.io.InputStream;import java.io.OutputStream;
import androidx.appcompat.app.AppCompatActivity;
public class MainActivity extends AppCompatActivity {
@Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main);
Bitmap bitmap = null; Module module = null; try { // creating bitmap from packaged into app android asset 'image.jpg', // app/src/main/assets/image.jpg bitmap = BitmapFactory.decodeStream(getAssets().open('1.png')); // loading serialized torchscript module from packaged into app android asset model.pt, // app/src/model/assets/model.pt module = Module.load(assetFilePath(this, 'models.pt')); } catch (IOException e) { Log.e('PytorchHelloWorld', 'Error reading assets', e); finish(); }
// showing image on UI ImageView imageView = findViewById(R.id.image); imageView.setImageBitmap(bitmap);
// preparing input tensor final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap, TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST);
// running the model final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
// getting tensor content as java array of floats final float[] scores = outputTensor.getDataAsFloatArray();
// searching for the index with maximum score float maxScore = -Float.MAX_VALUE; int maxScoreIdx = -1; for (int i = 0; i < scores.length; i++) { if (scores[i] > maxScore) { maxScore = scores[i]; maxScoreIdx = i; } } System.out.println(maxScoreIdx); String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
// showing className on UI TextView textView = findViewById(R.id.text); textView.setText(className); }
/** * Copies specified asset to the file in /files app directory and returns this file absolute path. * * @return absolute file path */ public static String assetFilePath(Context context, String assetName) throws IOException { File file = new File(context.getFilesDir(), assetName); if (file.exists() && file.length() > 0) { return file.getAbsolutePath(); }
try (InputStream is = context.getAssets().open(assetName)) { try (OutputStream os = new FileOutputStream(file)) { byte[] buffer = new byte[4 * 1024]; int read; while ((read = is.read(buffer)) != -1) { os.write(buffer, 0, read); } os.flush(); } return file.getAbsolutePath(); } }}

然后運行。

Image

    本站是提供個人知識管理的網(wǎng)絡存儲空間,所有內(nèi)容均由用戶發(fā)布,不代表本站觀點。請注意甄別內(nèi)容中的聯(lián)系方式、誘導購買等信息,謹防詐騙。如發(fā)現(xiàn)有害或侵權(quán)內(nèi)容,請點擊一鍵舉報。
    轉(zhuǎn)藏 分享 獻花(0

    0條評論

    發(fā)表

    請遵守用戶 評論公約

    類似文章 更多