
Using a PyTorch Model for Inference in Java

Author: Gwen0x4c3

My Java backend needs to analyze some data, but I couldn't find a suitable approach, so I decided to write a small PyTorch model as a stopgap.

I use DJL (Deep Java Library) to call the PyTorch engine.

GitHub: djl/README.md at master · deepjavalibrary/djl

Add the dependencies to pom.xml:

    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-engine</artifactId>
        <version>0.16.0</version>
    </dependency>
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-native-auto</artifactId>
        <version>1.9.1</version>
        <scope>runtime</scope>
    </dependency>
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-jni</artifactId>
        <version>1.9.1-0.16.0</version>
        <scope>runtime</scope>
    </dependency>

Note that the versions must match: each PyTorch engine version only works with specific PyTorch native library versions.

| PyTorch engine version | PyTorch native library version |
|---|---|
| pytorch-engine:0.15.0 | pytorch-native-auto: 1.8.1, 1.9.1, 1.10.0 |
| pytorch-engine:0.14.0 | pytorch-native-auto: 1.8.1, 1.9.0, 1.9.1 |
| pytorch-engine:0.13.0 | pytorch-native-auto: 1.9.0 |
| pytorch-engine:0.12.0 | pytorch-native-auto: 1.8.1 |
| pytorch-engine:0.11.0 | pytorch-native-auto: 1.8.1 |
| pytorch-engine:0.10.0 | pytorch-native-auto: 1.7.1 |
| pytorch-engine:0.9.0 | pytorch-native-auto: 1.7.0 |
| pytorch-engine:0.8.0 | pytorch-native-auto: 1.6.0 |
| pytorch-engine:0.7.0 | pytorch-native-auto: 1.6.0 |
| pytorch-engine:0.6.0 | pytorch-native-auto: 1.5.0 |
| pytorch-engine:0.5.0 | pytorch-native-auto: 1.4.0 |
| pytorch-engine:0.4.0 | pytorch-native-auto: 1.4.0 |
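To make the rule concrete, here is a minimal Java sketch that encodes the table above as a lookup, so you can check a pairing before editing the pom. The class and method names (`VersionCompat`, `isSupported`, `jniVersion`) are hypothetical. The `<native>-<engine>` format used by `jniVersion` is only inferred from the `pytorch-jni` version `1.9.1-0.16.0` in the pom above.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class VersionCompat {
    // Engine version -> supported native library versions, copied from the table (0.4.0 to 0.15.0 only)
    private static final Map<String, List<String>> SUPPORTED = new HashMap<>();

    static {
        SUPPORTED.put("0.15.0", Arrays.asList("1.8.1", "1.9.1", "1.10.0"));
        SUPPORTED.put("0.14.0", Arrays.asList("1.8.1", "1.9.0", "1.9.1"));
        SUPPORTED.put("0.13.0", Collections.singletonList("1.9.0"));
        SUPPORTED.put("0.12.0", Collections.singletonList("1.8.1"));
        SUPPORTED.put("0.11.0", Collections.singletonList("1.8.1"));
        SUPPORTED.put("0.10.0", Collections.singletonList("1.7.1"));
        SUPPORTED.put("0.9.0", Collections.singletonList("1.7.0"));
        SUPPORTED.put("0.8.0", Collections.singletonList("1.6.0"));
        SUPPORTED.put("0.7.0", Collections.singletonList("1.6.0"));
        SUPPORTED.put("0.6.0", Collections.singletonList("1.5.0"));
        SUPPORTED.put("0.5.0", Collections.singletonList("1.4.0"));
        SUPPORTED.put("0.4.0", Collections.singletonList("1.4.0"));
    }

    private VersionCompat() {}

    /** True only if the pairing is listed in the table; unknown engine versions return false. */
    static boolean isSupported(String engineVersion, String nativeVersion) {
        List<String> natives = SUPPORTED.get(engineVersion);
        return natives != null && natives.contains(nativeVersion);
    }

    /** Builds a pytorch-jni version string in the assumed "<native>-<engine>" form. */
    static String jniVersion(String engineVersion, String nativeVersion) {
        return nativeVersion + "-" + engineVersion;
    }
}
```

`isSupported("0.16.0", "1.9.1")` returns false here simply because 0.16.0, the engine version used in the pom above, is not in the table. For newer engine versions, check the DJL PyTorch Engine page.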

For other issues, see: PyTorch Engine - Deep Java Library


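The official example is image classification, but I only need plain numeric data, not image input.

So I wrote a quick example: the input is [a, b] and the output is a number between 0 and 1.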
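The post does not show `LinearModel`. As an assumption, suppose it is a single linear layer with two inputs followed by a sigmoid, which is one common way to get an output between 0 and 1. Under that assumption the forward pass is just sigmoid(w1·a + w2·b + bias). A plain-Java version like the hypothetical `LinearModelReference` below could then sanity-check the numbers DJL returns, once you copy in the real trained weights.

```java
final class LinearModelReference {
    // Hypothetical parameters; take the real values from the trained model's state_dict
    private final float w1;
    private final float w2;
    private final float bias;

    LinearModelReference(float w1, float w2, float bias) {
        this.w1 = w1;
        this.w2 = w2;
        this.bias = bias;
    }

    /** Computes sigmoid(w1*a + w2*b + bias) in double precision; the result lies between 0 and 1. */
    float forward(float a, float b) {
        double z = (double) w1 * a + (double) w2 * b + bias;
        return (float) (1.0 / (1.0 + Math.exp(-z)));
    }
}
```

For example, `new LinearModelReference(0.5f, -0.25f, 0.1f).forward(0.72f, 0.94f)` computes the sigmoid of 0.225, which is a little above 0.5.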

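I still recommend training the model in Python first rather than in Java. Once the model is trained, the first step is to convert the PyTorch model to TorchScript. A TorchScript file contains both the model structure and its parameters.

From the official documentation: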

There are two ways to convert your model to TorchScript: tracing and scripting. We will only demonstrate the first one, tracing, but you can find information about scripting from the PyTorch documentation. When tracing, we use an example input to record the actions taken and capture the model architecture. This works best when your model doesn't have control flow. If you do have control flow, you will need to use the scripting approach. In DJL, we use tracing to create TorchScript for our ModelZoo models.

Here is an example of tracing in action:

    import torch
    import torchvision
    
    # An instance of your model.
    model = torchvision.models.resnet18(pretrained=True)
    
    # Switch the model to eval mode
    model.eval()
    
    # An example input you would normally provide to your model's forward() method.
    example = torch.rand(1, 3, 224, 224)
    
    # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
    traced_script_module = torch.jit.trace(model, example)
    
    # Save the TorchScript model
    traced_script_module.save("traced_resnet_model.pt")

If your model uses dropout or similar layers, be sure to call model.eval() before tracing and saving.

For my model, it looks like this:

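    import torch

    # LinearModel is the model class from your training code
    model = LinearModel()
    model.load_state_dict(torch.load("model.pth"))

    # Switch to eval mode before tracing (required if the model uses dropout)
    model.eval()

    # Create an example input that matches your model; here a 1-D float32 tensor of two values
    example_input = torch.tensor([0.72, 0.94]).float()

    script = torch.jit.trace(model, example_input)
    script.save("model.pt")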
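Here is why the `model.eval()` advice matters. In training mode, inverted dropout zeroes each element with probability p and scales the survivors by 1/(1-p). A model traced in training mode would therefore give a different result on every call. In eval mode, dropout is simply the identity. `DropoutDemo` below is a hypothetical standalone illustration of that behavior in Java, not PyTorch's implementation.

```java
import java.util.Random;

final class DropoutDemo {
    private DropoutDemo() {}

    /**
     * Inverted dropout. In training mode each element is zeroed with probability p
     * and the rest are scaled by 1/(1-p); in eval mode the input is returned unchanged.
     * Assumes 0 <= p < 1.
     */
    static float[] dropout(float[] x, float p, boolean training, Random rng) {
        float[] out = new float[x.length];
        if (!training) {
            System.arraycopy(x, 0, out, 0, x.length);
            return out;
        }
        float scale = 1.0f / (1.0f - p);
        for (int i = 0; i < x.length; i++) {
            out[i] = rng.nextFloat() < p ? 0.0f : x[i] * scale;
        }
        return out;
    }
}
```

Calling `dropout(x, 0.5f, true, rng)` twice on the same input usually yields different arrays. With `training` set to false, the output is always equal to `x`, which is the deterministic behavior you want baked into the traced model.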

Next, the Java code.

Official example: Load a PyTorch Model - Deep Java Library

And this one: 03 image classification with your model - Deep Java Library

My data doesn't need any transforms. The code:

    import ai.djl.MalformedModelException;
    import ai.djl.Model;
    import ai.djl.inference.Predictor;
    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDList;
    import ai.djl.ndarray.NDManager;
    import ai.djl.translate.NoBatchifyTranslator;
    import ai.djl.translate.TranslateException;
    import ai.djl.translate.TranslatorContext;

    import java.io.IOException;
    import java.nio.file.Paths;

    public class PytorchInJava {
        public static void main(String[] args) {
            // First create a model; Model is AutoCloseable, so try-with-resources releases native memory
            try (Model model = Model.newInstance("test")) {
                model.load(Paths.get("C:\\Users\\Administrator\\IdeaProjects\\PytorchInJava\\src\\main\\resources\\model.pt"));
                System.out.println(model);

                // NoBatchifyTranslator<input type, output type>; for image input the input type would be Image.
                // My input is float32, so use float[], not double[]
                NoBatchifyTranslator<float[], Float> translator = new NoBatchifyTranslator<float[], Float>() {
                    @Override
                    public NDList processInput(TranslatorContext ctx, float[] input) {
                        NDManager manager = ctx.getNDManager();
                        // This NDArray is the model input
                        NDArray array = manager.create(input);
                        return new NDList(array);
                    }

                    @Override
                    public Float processOutput(TranslatorContext ctx, NDList list) {
                        // The model returns a single value; read it as a float
                        return list.get(0).getFloat();
                    }
                };

                // The output type is Float (not Object), so the result can be assigned to a float
                try (Predictor<float[], Float> predictor = model.newPredictor(translator)) {
                    float result = predictor.predict(new float[]{0.6144011f, 0.952401f});
                    System.out.println("result: " + result);
                }
            } catch (IOException | MalformedModelException | TranslateException e) {
                e.printStackTrace();
            }
        }
    }
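The float32 comment in the code matters because `NDManager.create` takes the array's data type from the Java array type. A `double[]` would become a float64 array, which does not match a model traced with float32 input. If your backend data arrives as `double[]`, you can convert it before calling `predict` with a small helper like the hypothetical `InputConversion` below. The two-element length check mirrors the `[a, b]` example input and is an assumption about your model.

```java
final class InputConversion {
    private InputConversion() {}

    /** Narrows each double to float32, the dtype the traced model was exported with. */
    static float[] toFloat32(double[] values) {
        float[] out = new float[values.length];
        for (int i = 0; i < values.length; i++) {
            out[i] = (float) values[i];
        }
        return out;
    }

    /** Assumed guard: the traced example input had exactly two elements, [a, b]. */
    static float[] toModelInput(double[] values) {
        if (values.length != 2) {
            throw new IllegalArgumentException("Expected 2 values [a, b], got " + values.length);
        }
        return toFloat32(values);
    }
}
```

The result of `toModelInput(...)` can be passed straight to `predictor.predict(...)`.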


Update

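When I packaged the project as a jar and ran it on CentOS 7 Linux, it failed with UnsatisfiedLinkError. With help from an experienced developer, I found that the problem was the dependencies I had declared. The fix was to replace pytorch-native-auto with pytorch-native-cpu-precxx11 (a build intended for older Linux systems such as CentOS 7) with an explicit linux-x86_64 classifier, and to add a jna.version property.

The revised dependencies: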

    <properties>
        <java.version>8</java.version>
        <jna.version>5.3.0</jna.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>0.16.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-cpu-precxx11</artifactId>
            <classifier>linux-x86_64</classifier>
            <version>1.9.1</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-jni</artifactId>
            <version>1.9.1-0.16.0</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
    </dependencies>
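The `linux-x86_64` classifier has to match the server's platform. As a convenience, the sketch below turns Java's `os.name` and `os.arch` system properties into a classifier string of the same `<os>-<arch>` form. Only `linux-x86_64` comes from the pom above. The other names (`win-x86_64`, `osx-x86_64` and the `aarch64` variants) are assumptions to verify against the DJL documentation, and `NativeClassifier` is a hypothetical helper.

```java
import java.util.Locale;

final class NativeClassifier {
    private NativeClassifier() {}

    /** Maps os.name / os.arch values to an assumed "<os>-<arch>" classifier string. */
    static String classifier(String osName, String osArch) {
        String name = osName.toLowerCase(Locale.ROOT);
        String os;
        if (name.startsWith("linux")) {
            os = "linux";
        } else if (name.startsWith("windows")) {
            os = "win";
        } else if (name.startsWith("mac")) {
            os = "osx";
        } else {
            throw new IllegalArgumentException("Unsupported OS: " + osName);
        }

        String arch;
        switch (osArch) {
            case "amd64":
            case "x86_64":
                arch = "x86_64";
                break;
            case "aarch64":
            case "arm64":
                arch = "aarch64";
                break;
            default:
                throw new IllegalArgumentException("Unsupported architecture: " + osArch);
        }
        return os + "-" + arch;
    }
}
```

Running `NativeClassifier.classifier(System.getProperty("os.name"), System.getProperty("os.arch"))` on the CentOS 7 server should give `linux-x86_64`. On 64-bit Linux the JVM typically reports `os.arch` as `amd64`.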