
「导语」TensorFlow Serving 提供了 GRPC 接口来高效地完成对模型的预测请求,但是它本身只提供了基于 Python 的 API ,如果我们要使用其它语言进行 GRPC 访问,则需手动生成相应的 GRPC 接口文件方可。本文主要介绍使用 protoc 工具生成 TensorFlow Serving API 文件的方式与方法,并且提供完整的项目示例以供参考。
ProtoBuf 文件编译
TensorFlow Serving 是基于 Protocol Buffer 协议来进行 GPRC 通信的,其源码中以 proto 为后缀的文件里定义了一系列数据结构 (message) 以及 RPC 服务 (service) ,它们分别用来表示数据交换的格式以及执行远程操作的接口。 proto 文件本身并不能直接在代码中使用,它们需要进一步转为语言相关的代码文件才能被正常编译与运行。
Protocol Buffer 官方提供了 Protocol Buffer Compiler (protoc) 编译工具来对 proto 文件进行编译并生成语言相关的代码文件,该工具目前支持多种语言的代码生成工作,包括 golang , java , c++ 以及 c# 等。使用 protoc 工具可以极大地减少我们编码的工作量,使我们能够更加专注于具体的业务实现,而无需为定义各种语言相关的数据结构而苦恼。
因此,在使用其它语言与 TensorFlow Serving 进行 GRPC 通信时,我们需要借助 protoc 工具来生成语言相关的 API 文件,以供后续使用。需要注意的是,由于 TensorFlow Serving 源码中的部分 proto 文件需要依赖于 TensorFlow 中的 proto 文件,所以我们需要同时使用两者的源码来生成所需的 API 文件。
下面来简要介绍下 protoc 工具在 Linux 系统下的安装流程:
首先在
Protocol Buffer的Github软件发布页面下载最新版本的protoc二进制压缩包文件,或者使用如下命令直接下载。wget https://github.com/protocolbuffers/protobuf/releases/download/v3.12.3/protoc-3.12.3-linux-x86_64.zip
然后将该压缩包解压到
/usr/local/protoc目录下。unzip protoc-3.12.3-linux-x86_64.zip -d /usr/local/protoc
接着将
/usr/local/protoc/bin目录加入到PATH环境变量中。可以将下面这行语句加入到/etc/profile文件中来达成上述目标。export PATH=$PATH:/usr/local/protoc/bin
最后测试安装成功
protoc --version
API 文件生成及使用
一般而言,只要是 Protocol Buffer 和 GRPC 支持的语言,都可以生成 TensorFlow Serving 的 API 文件。在上一篇文章TensorFlow 2.x 模型 Serving 服务中我已经介绍过使用 Python 来进行 GPRC 请求的示例,本文则主要介绍使用 Golang 以及 Java 来生成 TensorFlow Serving 的 API 文件以及进行 GPRC 请求的方法。
为了生成可执行的代码文件,我们首先需要将 TensorFlow 和 TensorFlow Serving 的源码都 clone 到本地。
mkdir tensorflow-serving-api && cd tensorflow-serving-api |
接下来就可以使用源码中的 proto 文件来生成相应语言的 TensorFlow Serving API 文件了。
Golang
在生成 golang 相关的 API 代码文件时,我们需要安装 golang 环境以及一些 protoc 插件来辅助我们进行文件生成操作。以下的相关操作均在 Linux 系统完成。
安装
golang,流程如下所示。wget https://dl.google.com/go/go1.14.4.linux-amd64.tar.gz
tar zxvf go1.14.4.linux-amd64.tar.gz -C /usr/local
export PATH=$PATH:/usr/local/go/bin
go version安装
protoc-gen-go插件,用于生成go文件。go get -u google.golang.org/protobuf/cmd/protoc-gen-go
or
go get -u github.com/golang/protobuf/protoc-gen-goprotoc-gen-go默认会安装在$GOPATH/bin目录下,需要确保该目录在PATH下以使得protoc工具能够找到该插件。安装
grpc插件,用于生成grpc go文件。go get -u google.golang.org/grpc
将源码切换到指定分支或标签。
cd tensorflow-serving-api/tensorflow
git checkout tags/v2.2.0
cd tensorflow-serving-api/serving
git checkout tags/2.2.0使用
protoc工具生成go文件。cd tensorflow-serving-api
protoc -I=serving -I=tensorflow --go_out=plugins=grpc:golang serving/tensorflow_serving/*/*.proto
protoc -I=serving -I=tensorflow --go_out=plugins=grpc:golang serving/tensorflow_serving/sources/storage_path/*.proto
protoc -I=serving -I=tensorflow --go_out=plugins=grpc:golang tensorflow/tensorflow/core/framework/*.proto
protoc -I=serving -I=tensorflow --go_out=plugins=grpc:golang tensorflow/tensorflow/core/example/*.proto
protoc -I=serving -I=tensorflow --go_out=plugins=grpc:golang tensorflow/tensorflow/core/protobuf/*.proto
protoc -I=serving -I=tensorflow --go_out=plugins=grpc:golang tensorflow/tensorflow/stream_executor/*.proto其中
-I指定了proto文件搜索依赖文件的路径,可以指定多次。--go_out指定了保存go文件的目录(这里为golang)以及使用的grpc插件。命令的最后一项以通配符的形式指定了proto文件的输入位置。至于为何选择上述的
proto文件进行API生成,是根据实际使用情况以及proto文件之间的依赖关系决定的。可以先从serving的proto源码入手,并参照其Python GRPC示例的代码实现,找到入口的proto文件,然后根据其本身及依赖的proto文件生成相应的API代码文件,接着进行编码测试,查缺补漏,直到所有代码文件编译无误为止。执行完上述命令后会在
golang目录下生成两个目录,分别为github.com和tensorflow_serving,前者包含有从tensorflow源码中的proto文件生成的go文件,后者包含从serving源码中的proto文件生成的go文件。因为tensorflow源码中的proto文件均包含go_package选项如option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_go_proto";,它们指定了生成的go文件的输出目录,所以其生成的go文件会在github.com目录下,而serving源码中的proto文件不包含该选项,所以go文件的输出目录默认与源码文件的目录相同。tensorflow_serving目录下生成的go文件中可能会出现循环引用错误,如下所示:import cycle not allowed
package github.com/alex/tensorflow-serving-api-go
imports tensorflow_serving/apis
imports tensorflow_serving/core
imports tensorflow_serving/apis此时你需要将
tensorflow_serving/core目录下logging.pb.go文件和tensorflow_serving/apis目录下的prediction_log.pb.go文件删除以解决上述问题。删除上述代码文件并不影响后续的GRPC模型预测请求。假设我有一个名为
first_model的模型部署在了TensorFlow Serving服务上,它的元数据信息如下所示:curl http://localhost:8501/v1/models/first_model/versions/0/metadata
{
"model_spec": {
"name": "first_model",
"signature_name": "",
"version": "0"
},
"metadata": {
"signature_def": {
"signature_def": {
"serving_default": {
"inputs": {
"input_1": {
"dtype": "DT_INT64",
"tensor_shape": {
"dim": [
{
"size": "-1",
"name": ""
},
{
"size": "31",
"name": ""
}
],
"unknown_rank": false
},
"name": "serving_default_input_1:0"
}
},
"outputs": {
"output_1": {
"dtype": "DT_FLOAT",
"tensor_shape": {
"dim": [
{
"size": "-1",
"name": ""
},
{
"size": "1",
"name": ""
}
],
"unknown_rank": false
},
"name": "StatefulPartitionedCall:0"
}
},
"method_name": "tensorflow/serving/predict"
},
"__saved_model_init_op": {
"inputs": {},
"outputs": {
"__saved_model_init_op": {
"dtype": "DT_INVALID",
"tensor_shape": {
"dim": [],
"unknown_rank": true
},
"name": "NoOp"
}
},
"method_name": ""
}
}
}
}
}我们需要重点关注上述信息中的
inputs选项,它定义了该模型输入数据的key值(这里为input_1) 、输入数据的维度(这里为(-1, 31))以及输入数据的类型 (这里为DT_INT64)。在进行GRPC预测请求时,代码中指定的输入数据需要与元数据中定义的各种输入信息相匹配,否则就无法获取正确的模型输出。创建
go项目并向first_model发送GRPC预测请求。具体的项目详情请参见Github上的实现及说明,这里只列出主函数的代码,如下所示:package main
import (
"context"
"log"
apis "tensorflow_serving/apis"
"time"
"github.com/golang/protobuf/ptypes/wrappers"
"github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_go_proto"
"github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_shape_go_proto"
"github.com/tensorflow/tensorflow/tensorflow/go/core/framework/types_go_proto"
"google.golang.org/grpc"
)
var (
// TensorFlow serving grpc address.
address = "127.0.0.1:8500"
)
func main() {
// Create a grpc request.
request := &apis.PredictRequest{
ModelSpec: &apis.ModelSpec{},
Inputs: make(map[string]*tensor_go_proto.TensorProto),
}
request.ModelSpec.Name = "first_model"
request.ModelSpec.SignatureName = "serving_default"
// request.ModelSpec.VersionChoice = &apis.ModelSpec_VersionLabel{VersionLabel: "stable"}
request.ModelSpec.VersionChoice = &apis.ModelSpec_Version{Version: &wrappers.Int64Value{Value: 0}}
request.Inputs["input_1"] = &tensor_go_proto.TensorProto{
Dtype: types_go_proto.DataType_DT_INT64,
TensorShape: &tensor_shape_go_proto.TensorShapeProto{
Dim: []*tensor_shape_go_proto.TensorShapeProto_Dim{
&tensor_shape_go_proto.TensorShapeProto_Dim{
Size: int64(2),
},
&tensor_shape_go_proto.TensorShapeProto_Dim{
Size: int64(31),
},
},
},
Int64Val: []int64{
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
},
}
// Create a grpc connection.
conn, err := grpc.Dial(address, grpc.WithInsecure(), grpc.WithBlock(), grpc.WithTimeout(10*time.Second))
if err != nil {
log.Fatalf("couldn't connect: %s", err.Error())
}
defer conn.Close()
// Wrap the grpc uri with client.
client := apis.NewPredictionServiceClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// Send the grpc request.
response, err := client.Predict(ctx, request)
if err != nil {
log.Fatalf("couldn't get response: %v", err)
}
log.Printf("%+v", response)
}Github地址为:https://github.com/AlexanderJLiu/tensorflow-serving-api/tree/master/golang
Java
在生成 java 相关的 API 代码文件时,我们需要安装 java 环境以及一些 protoc 插件来辅助我们进行文件生成操作。以下相关操作均在 Linux 系统完成。
安装
OpenJDK。centos
yum-config-manager --enable rhel-7-server-optional-rpms
yum install java-11-openjdk-devel
ubuntu
apt-get install openjdk-11-jdk
test
java -version安装
protoc-gen-grpc-java,用于生成grpc java文件。wget https://repo1.maven.org/maven2/io/grpc/protoc-gen-grpc-java/1.30.2/protoc-gen-grpc-java-1.30.2-linux-x86_64.exe
mv protoc-gen-grpc-java-1.30.2-linux-x86_64.exe /usr/local/protoc/bin/protoc-gen-grpc-java将源码切换到指定分支或标签。
cd tensorflow-serving-api/tensorflow
git checkout tags/v2.2.0
cd tensorflow-serving-api/serving
git checkout tags/2.2.0使用
protoc工具生成java文件。cd tensorflow-serving-api
protoc -I=serving -I=tensorflow --plugin=/usr/local/protoc/bin/protoc-gen-grpc-java --grpc-java_out=java --java_out=java serving/tensorflow_serving/*/*.proto
protoc -I=serving -I=tensorflow --plugin=/usr/local/protoc/bin/protoc-gen-grpc-java --grpc-java_out=java --java_out=java serving/tensorflow_serving/sources/storage_path/*.proto其中
-I指定了proto文件搜索依赖文件的路径,可以指定多次。--plugin指定了要使用的grpc插件的路径。--grpc-java_out指定了grpc java文件的保存目录。--java_out指定了java文件的保存目录。命令的最后一项以通配符的形式指定了proto文件的输入位置。由于
TensorFlow官方已经基于proto文件生成了TensorFlow的java文件,因此我们就没有必要自己生成了,在使用时直接从maven仓库引入即可:implementation("org.tensorflow:proto:1.15.0")。创建
java项目并向first_model发送GRPC预测请求。具体的项目详情请参见Github上的实现及说明,这里只列出主函数的代码,如下所示:package com.github.alex;
import java.util.Arrays;
import java.util.concurrent.TimeUnit;
import com.google.protobuf.Int64Value;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import org.tensorflow.framework.TensorShapeProto.Dim;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.stub.StreamObserver;
import tensorflow.serving.Model.ModelSpec;
import tensorflow.serving.Predict.PredictRequest;
import tensorflow.serving.Predict.PredictResponse;
import tensorflow.serving.PredictionServiceGrpc;
import tensorflow.serving.PredictionServiceGrpc.PredictionServiceBlockingStub;
import tensorflow.serving.PredictionServiceGrpc.PredictionServiceStub;
public class App {
public String getGreeting() {
return "Hello world.";
}
public static void main(String[] args) {
PredictRequest.Builder requestBuilder = PredictRequest.newBuilder();
ModelSpec.Builder modelSpecBuilder = ModelSpec.newBuilder();
modelSpecBuilder.setSignatureName("serving_default");
modelSpecBuilder.setName("first_model");
modelSpecBuilder.setVersion(Int64Value.newBuilder().setValue(0L));
requestBuilder.setModelSpec(modelSpecBuilder.build());
TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
tensorProtoBuilder.setDtype(DataType.DT_INT64);
Dim[] dim = {Dim.newBuilder().setSize(1).build(), Dim.newBuilder().setSize(31).build()};
tensorProtoBuilder
.setTensorShape(TensorShapeProto.newBuilder().addAllDim(Arrays.asList(dim)));
// tensorProtoBuilder.setTensorShape(TensorShapeProto.newBuilder()
// .addDim(Dim.newBuilder().setSize(1)).addDim(Dim.newBuilder().setSize(31)));
Long[] inputs = {1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L,
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L};
tensorProtoBuilder.addAllInt64Val(Arrays.asList(inputs));
requestBuilder.putInputs("input_1", tensorProtoBuilder.build());
PredictRequest request = requestBuilder.build();
System.out.println(request);
String target = "127.0.0.1:8500";
// Create a communication channel to the server, known as a Channel. Channels are
// thread-safe and reusable. It is common to create channels at the beginning of your
// application and reuse them until the application shuts down.
ManagedChannel channel = ManagedChannelBuilder.forTarget(target)
// Channels are secure by default (via SSL/TLS). For the example we disable TLS to
// avoid needing certificates.
.usePlaintext().build();
PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);
try {
PredictResponse response = stub.predict(request);
System.out.println(response);
channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}Github地址为:https://github.com/AlexanderJLiu/tensorflow-serving-api/tree/master/java
其他语言
我在 Github 上创建了一个名为 tensorflow-serving-api 项目,该项目旨在生成 Protocol Buffer 和 GRPC 所支持的所有语言的 TensorFlow Serving API 文件,并且以完整项目的形式给出使用的示例。
该项目正在逐渐完善中,目前已经实现了 Golang , Java 以及 Python 语言的 TensorFlow Serving API 和项目示例,后续还会加入更多的语言实现,也欢迎大家来共同参与贡献。
项目链接地址:https://github.com/AlexanderJLiu/tensorflow-serving-api
参考资料
- Go Installation Instructions
- Protocol Buffer Basics: Go
- ProtoBuf: Go Generated Code
- Go GRPC Examples
- Install OpenJDK on Windows and Linux
- Protocol Buffer Basics: Java
- ProtoBuf: Java Generated Code
- GRPC: Java Generated-code Reference
