逑识

吾生也有涯,而知也无涯,以无涯奉有涯,其易欤?

0%

TensorFlow Serving API

导语」TensorFlow Serving 提供了 GRPC 接口来高效地完成对模型的预测请求,但是它本身只提供了基于 Python 的 API ,如果我们要使用其它语言进行 GRPC 访问,则需手动生成相应的 GRPC 接口文件方可。本文主要介绍使用 protoc 工具生成 TensorFlow Serving API 文件的方式与方法,并且提供完整的项目示例以供参考。

ProtoBuf 文件编译

TensorFlow Serving 是基于 Protocol Buffer 协议来进行 GPRC 通信的,其源码中以 proto 为后缀的文件里定义了一系列数据结构 (message) 以及 RPC 服务 (service) ,它们分别用来表示数据交换的格式以及执行远程操作的接口。 proto 文件本身并不能直接在代码中使用,它们需要进一步转为语言相关的代码文件才能被正常编译与运行。

Protocol Buffer 官方提供了 Protocol Buffer Compiler (protoc) 编译工具来对 proto 文件进行编译并生成语言相关的代码文件,该工具目前支持多种语言的代码生成工作,包括 golangjavac++ 以及 c# 等。使用 protoc 工具可以极大地减少我们编码的工作量,使我们能够更加专注于具体的业务实现,而无需为定义各种语言相关的数据结构而苦恼。

因此,在使用其它语言与 TensorFlow Serving 进行 GRPC 通信时,我们需要借助 protoc 工具来生成语言相关的 API 文件,以供后续使用。需要注意的是,由于 TensorFlow Serving 源码中的部分 proto 文件需要依赖于 TensorFlow 中的 proto 文件,所以我们需要同时使用两者的源码来生成所需的 API 文件。

下面来简要介绍下 protoc 工具在 Linux 系统下的安装流程:

  1. 首先在 Protocol BufferGithub 软件发布页面下载最新版本的 protoc 二进制压缩包文件,或者使用如下命令直接下载。

    wget https://github.com/protocolbuffers/protobuf/releases/download/v3.12.3/protoc-3.12.3-linux-x86_64.zip
  2. 然后将该压缩包解压到 /usr/local/protoc 目录下。

    unzip protoc-3.12.3-linux-x86_64.zip -d /usr/local/protoc
  3. 接着将 /usr/local/protoc/bin 目录加入到 PATH 环境变量中。可以将下面这行语句加入到 /etc/profile 文件中来达成上述目标。

    export PATH=$PATH:/usr/local/protoc/bin
  4. 最后测试安装成功

    protoc --version

API 文件生成及使用

一般而言,只要是 Protocol BufferGRPC 支持的语言,都可以生成 TensorFlow ServingAPI 文件。在上一篇文章TensorFlow 2.x 模型 Serving 服务中我已经介绍过使用 Python 来进行 GPRC 请求的示例,本文则主要介绍使用 Golang 以及 Java 来生成 TensorFlow ServingAPI 文件以及进行 GPRC 请求的方法。

为了生成可执行的代码文件,我们首先需要将 TensorFlowTensorFlow Serving 的源码都 clone 到本地。

mkdir tensorflow-serving-api && cd tensorflow-serving-api
git clone https://github.com/tensorflow/tensorflow.git
git clone https://github.com/tensorflow/serving.git

接下来就可以使用源码中的 proto 文件来生成相应语言的 TensorFlow Serving API 文件了。

Golang

在生成 golang 相关的 API 代码文件时,我们需要安装 golang 环境以及一些 protoc 插件来辅助我们进行文件生成操作。以下的相关操作均在 Linux 系统完成。

  1. 安装 golang ,流程如下所示。

    wget https://dl.google.com/go/go1.14.4.linux-amd64.tar.gz
    tar zxvf go1.14.4.linux-amd64.tar.gz -C /usr/local
    export PATH=$PATH:/usr/local/go/bin
    go version
  2. 安装 protoc-gen-go 插件,用于生成 go 文件。

    go get -u google.golang.org/protobuf/cmd/protoc-gen-go
    # or
    go get -u github.com/golang/protobuf/protoc-gen-go

    protoc-gen-go 默认会安装在 $GOPATH/bin 目录下,需要确保该目录在 PATH 下以使得 protoc 工具能够找到该插件。

  3. 安装 grpc 插件,用于生成 grpc go 文件。

    go get -u google.golang.org/grpc
  4. 将源码切换到指定分支或标签。

    cd tensorflow-serving-api/tensorflow
    git checkout tags/v2.2.0
    cd tensorflow-serving-api/serving
    git checkout tags/2.2.0
  5. 使用 protoc 工具生成 go 文件。

    cd tensorflow-serving-api
    protoc -I=serving -I=tensorflow --go_out=plugins=grpc:golang serving/tensorflow_serving/*/*.proto
    protoc -I=serving -I=tensorflow --go_out=plugins=grpc:golang serving/tensorflow_serving/sources/storage_path/*.proto
    protoc -I=serving -I=tensorflow --go_out=plugins=grpc:golang tensorflow/tensorflow/core/framework/*.proto
    protoc -I=serving -I=tensorflow --go_out=plugins=grpc:golang tensorflow/tensorflow/core/example/*.proto
    protoc -I=serving -I=tensorflow --go_out=plugins=grpc:golang tensorflow/tensorflow/core/protobuf/*.proto
    protoc -I=serving -I=tensorflow --go_out=plugins=grpc:golang tensorflow/tensorflow/stream_executor/*.proto

    其中 -I 指定了 proto 文件搜索依赖文件的路径,可以指定多次。 --go_out 指定了保存 go 文件的目录(这里为 golang )以及使用的 grpc 插件。命令的最后一项以通配符的形式指定了 proto 文件的输入位置。

    至于为何选择上述的 proto 文件进行 API 生成,是根据实际使用情况以及 proto 文件之间的依赖关系决定的。可以先从 servingproto 源码入手,并参照其 Python GRPC 示例的代码实现,找到入口的 proto 文件,然后根据其本身及依赖的 proto 文件生成相应的 API 代码文件,接着进行编码测试,查缺补漏,直到所有代码文件编译无误为止。

  6. 执行完上述命令后会在 golang 目录下生成两个目录,分别为 github.comtensorflow_serving ,前者包含有从 tensorflow 源码中的 proto 文件生成的 go 文件,后者包含从 serving 源码中的 proto 文件生成的 go 文件。因为 tensorflow 源码中的 proto 文件均包含 go_package 选项如 option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_go_proto"; ,它们指定了生成的 go 文件的输出目录,所以其生成的 go 文件会在 github.com 目录下,而 serving 源码中的 proto 文件不包含该选项,所以 go 文件的输出目录默认与源码文件的目录相同。

  7. tensorflow_serving 目录下生成的 go 文件中可能会出现循环引用错误,如下所示:

    import cycle not allowed
    package github.com/alex/tensorflow-serving-api-go
    imports tensorflow_serving/apis
    imports tensorflow_serving/core
    imports tensorflow_serving/apis

    此时你需要将 tensorflow_serving/core 目录下 logging.pb.go 文件和 tensorflow_serving/apis 目录下的 prediction_log.pb.go 文件删除以解决上述问题。删除上述代码文件并不影响后续的 GRPC 模型预测请求。

  8. 假设我有一个名为 first_model 的模型部署在了 TensorFlow Serving 服务上,它的元数据信息如下所示:

    $ curl http://localhost:8501/v1/models/first_model/versions/0/metadata
    {
    "model_spec": {
    "name": "first_model",
    "signature_name": "",
    "version": "0"
    },
    "metadata": {
    "signature_def": {
    "signature_def": {
    "serving_default": {
    "inputs": {
    "input_1": {
    "dtype": "DT_INT64",
    "tensor_shape": {
    "dim": [
    {
    "size": "-1",
    "name": ""
    },
    {
    "size": "31",
    "name": ""
    }
    ],
    "unknown_rank": false
    },
    "name": "serving_default_input_1:0"
    }
    },
    "outputs": {
    "output_1": {
    "dtype": "DT_FLOAT",
    "tensor_shape": {
    "dim": [
    {
    "size": "-1",
    "name": ""
    },
    {
    "size": "1",
    "name": ""
    }
    ],
    "unknown_rank": false
    },
    "name": "StatefulPartitionedCall:0"
    }
    },
    "method_name": "tensorflow/serving/predict"
    },
    "__saved_model_init_op": {
    "inputs": {},
    "outputs": {
    "__saved_model_init_op": {
    "dtype": "DT_INVALID",
    "tensor_shape": {
    "dim": [],
    "unknown_rank": true
    },
    "name": "NoOp"
    }
    },
    "method_name": ""
    }
    }
    }
    }
    }

    我们需要重点关注上述信息中的 inputs 选项,它定义了该模型输入数据的 key 值(这里为 input_1) 、输入数据的维度(这里为 (-1, 31))以及输入数据的类型 (这里为 DT_INT64 )。在进行 GRPC 预测请求时,代码中指定的输入数据需要与元数据中定义的各种输入信息相匹配,否则就无法获取正确的模型输出。

  9. 创建 go 项目并向 first_model 发送 GRPC 预测请求。具体的项目详情请参见 Github 上的实现及说明,这里只列出主函数的代码,如下所示:

    package main

    import (
    "context"
    "log"
    apis "tensorflow_serving/apis"
    "time"

    "github.com/golang/protobuf/ptypes/wrappers"
    "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_go_proto"
    "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_shape_go_proto"
    "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/types_go_proto"
    "google.golang.org/grpc"
    )

    var (
    // TensorFlow serving grpc address.
    address = "127.0.0.1:8500"
    )

    func main() {
    // Create a grpc request.
    request := &apis.PredictRequest{
    ModelSpec: &apis.ModelSpec{},
    Inputs: make(map[string]*tensor_go_proto.TensorProto),
    }
    request.ModelSpec.Name = "first_model"
    request.ModelSpec.SignatureName = "serving_default"
    // request.ModelSpec.VersionChoice = &apis.ModelSpec_VersionLabel{VersionLabel: "stable"}
    request.ModelSpec.VersionChoice = &apis.ModelSpec_Version{Version: &wrappers.Int64Value{Value: 0}}
    request.Inputs["input_1"] = &tensor_go_proto.TensorProto{
    Dtype: types_go_proto.DataType_DT_INT64,
    TensorShape: &tensor_shape_go_proto.TensorShapeProto{
    Dim: []*tensor_shape_go_proto.TensorShapeProto_Dim{
    &tensor_shape_go_proto.TensorShapeProto_Dim{
    Size: int64(2),
    },
    &tensor_shape_go_proto.TensorShapeProto_Dim{
    Size: int64(31),
    },
    },
    },
    Int64Val: []int64{
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    },
    }

    // Create a grpc connection.
    conn, err := grpc.Dial(address, grpc.WithInsecure(), grpc.WithBlock(), grpc.WithTimeout(10*time.Second))
    if err != nil {
    log.Fatalf("couldn't connect: %s", err.Error())
    }
    defer conn.Close()

    // Wrap the grpc uri with client.
    client := apis.NewPredictionServiceClient(conn)
    ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    defer cancel()
    // Send the grpc request.
    response, err := client.Predict(ctx, request)
    if err != nil {
    log.Fatalf("couldn't get response: %v", err)
    }
    log.Printf("%+v", response)
    }
  10. Github 地址为:https://github.com/AlexanderJLiu/tensorflow-serving-api/tree/master/golang

Java

在生成 java 相关的 API 代码文件时,我们需要安装 java 环境以及一些 protoc 插件来辅助我们进行文件生成操作。以下相关操作均在 Linux 系统完成。

  1. 安装 OpenJDK

    # centos
    yum-config-manager --enable rhel-7-server-optional-rpms
    yum install java-11-openjdk-devel
    # ubuntu
    apt-get install openjdk-11-jdk
    # test
    java -version
  2. 安装 protoc-gen-grpc-java ,用于生成 grpc java 文件。

    wget https://repo1.maven.org/maven2/io/grpc/protoc-gen-grpc-java/1.30.2/protoc-gen-grpc-java-1.30.2-linux-x86_64.exe
    mv protoc-gen-grpc-java-1.30.2-linux-x86_64.exe /usr/local/protoc/bin/protoc-gen-grpc-java
  3. 将源码切换到指定分支或标签。

    cd tensorflow-serving-api/tensorflow
    git checkout tags/v2.2.0
    cd tensorflow-serving-api/serving
    git checkout tags/2.2.0
  4. 使用 protoc 工具生成 java 文件。

    cd tensorflow-serving-api
    protoc -I=serving -I=tensorflow --plugin=/usr/local/protoc/bin/protoc-gen-grpc-java --grpc-java_out=java --java_out=java serving/tensorflow_serving/*/*.proto
    protoc -I=serving -I=tensorflow --plugin=/usr/local/protoc/bin/protoc-gen-grpc-java --grpc-java_out=java --java_out=java serving/tensorflow_serving/sources/storage_path/*.proto

    其中 -I 指定了 proto 文件搜索依赖文件的路径,可以指定多次。 --plugin 指定了要使用的 grpc 插件的路径。 --grpc-java_out 指定了 grpc java 文件的保存目录。 --java_out 指定了 java 文件的保存目录。命令的最后一项以通配符的形式指定了 proto 文件的输入位置。

    由于 TensorFlow 官方已经基于 proto 文件生成了 TensorFlowjava 文件,因此我们就没有必要自己生成了,在使用时直接从 maven 仓库引入即可: implementation("org.tensorflow:proto:1.15.0")

  5. 创建 java 项目并向 first_model 发送 GRPC 预测请求。具体的项目详情请参见 Github 上的实现及说明,这里只列出主函数的代码,如下所示:

    package com.github.alex;

    import java.util.Arrays;
    import java.util.concurrent.TimeUnit;
    import com.google.protobuf.Int64Value;
    import org.tensorflow.framework.DataType;
    import org.tensorflow.framework.TensorProto;
    import org.tensorflow.framework.TensorShapeProto;
    import org.tensorflow.framework.TensorShapeProto.Dim;
    import io.grpc.ManagedChannel;
    import io.grpc.ManagedChannelBuilder;
    import io.grpc.stub.StreamObserver;
    import tensorflow.serving.Model.ModelSpec;
    import tensorflow.serving.Predict.PredictRequest;
    import tensorflow.serving.Predict.PredictResponse;
    import tensorflow.serving.PredictionServiceGrpc;
    import tensorflow.serving.PredictionServiceGrpc.PredictionServiceBlockingStub;
    import tensorflow.serving.PredictionServiceGrpc.PredictionServiceStub;

    public class App {
    public String getGreeting() {
    return "Hello world.";
    }

    public static void main(String[] args) {
    PredictRequest.Builder requestBuilder = PredictRequest.newBuilder();

    ModelSpec.Builder modelSpecBuilder = ModelSpec.newBuilder();
    modelSpecBuilder.setSignatureName("serving_default");
    modelSpecBuilder.setName("first_model");
    modelSpecBuilder.setVersion(Int64Value.newBuilder().setValue(0L));
    requestBuilder.setModelSpec(modelSpecBuilder.build());

    TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
    tensorProtoBuilder.setDtype(DataType.DT_INT64);
    Dim[] dim = {Dim.newBuilder().setSize(1).build(), Dim.newBuilder().setSize(31).build()};
    tensorProtoBuilder
    .setTensorShape(TensorShapeProto.newBuilder().addAllDim(Arrays.asList(dim)));
    // tensorProtoBuilder.setTensorShape(TensorShapeProto.newBuilder()
    // .addDim(Dim.newBuilder().setSize(1)).addDim(Dim.newBuilder().setSize(31)));
    Long[] inputs = {1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L,
    1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L};
    tensorProtoBuilder.addAllInt64Val(Arrays.asList(inputs));
    requestBuilder.putInputs("input_1", tensorProtoBuilder.build());

    PredictRequest request = requestBuilder.build();
    System.out.println(request);

    String target = "127.0.0.1:8500";
    // Create a communication channel to the server, known as a Channel. Channels are
    // thread-safe and reusable. It is common to create channels at the beginning of your
    // application and reuse them until the application shuts down.
    ManagedChannel channel = ManagedChannelBuilder.forTarget(target)
    // Channels are secure by default (via SSL/TLS). For the example we disable TLS to
    // avoid needing certificates.
    .usePlaintext().build();
    PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);
    try {
    PredictResponse response = stub.predict(request);
    System.out.println(response);
    channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS);
    } catch (InterruptedException e) {
    e.printStackTrace();
    }
    }
    }
  6. Github 地址为:https://github.com/AlexanderJLiu/tensorflow-serving-api/tree/master/java

其他语言

我在 Github 上创建了一个名为 tensorflow-serving-api 项目,该项目旨在生成 Protocol BufferGRPC 所支持的所有语言的 TensorFlow Serving API 文件,并且以完整项目的形式给出使用的示例。

该项目正在逐渐完善中,目前已经实现了 GolangJava 以及 Python 语言的 TensorFlow Serving API 和项目示例,后续还会加入更多的语言实现,也欢迎大家来共同参与贡献。

项目链接地址:https://github.com/AlexanderJLiu/tensorflow-serving-api

参考资料

  1. Go Installation Instructions
  2. Protocol Buffer Basics: Go
  3. ProtoBuf: Go Generated Code
  4. Go GRPC Examples
  5. Install OpenJDK on Windows and Linux
  6. Protocol Buffer Basics: Java
  7. ProtoBuf: Java Generated Code
  8. GRPC: Java Generated-code Reference

推荐阅读

欢迎关注我的其它发布渠道