「导语」TensorFlow Serving 提供了 GRPC 接口来高效地完成对模型的预测请求,但是它本身只提供了基于 Python 的 API ,如果我们要使用其它语言进行 GRPC 访问,则需手动生成相应的 GRPC 接口文件方可。本文主要介绍使用 protoc 工具生成 TensorFlow Serving API 文件的方式与方法,并且提供完整的项目示例以供参考。
ProtoBuf 文件编译
TensorFlow Serving
是基于 Protocol Buffer
协议来进行 GPRC
通信的,其源码中以 proto
为后缀的文件里定义了一系列数据结构 (message
) 以及 RPC
服务 (service
) ,它们分别用来表示数据交换的格式以及执行远程操作的接口。 proto
文件本身并不能直接在代码中使用,它们需要进一步转为语言相关的代码文件才能被正常编译与运行。
Protocol Buffer
官方提供了 Protocol Buffer Compiler
(protoc
) 编译工具来对 proto
文件进行编译并生成语言相关的代码文件,该工具目前支持多种语言的代码生成工作,包括 golang
, java
, c++
以及 c#
等。使用 protoc
工具可以极大地减少我们编码的工作量,使我们能够更加专注于具体的业务实现,而无需为定义各种语言相关的数据结构而苦恼。
因此,在使用其它语言与 TensorFlow Serving
进行 GRPC
通信时,我们需要借助 protoc
工具来生成语言相关的 API
文件,以供后续使用。需要注意的是,由于 TensorFlow Serving
源码中的部分 proto
文件需要依赖于 TensorFlow
中的 proto
文件,所以我们需要同时使用两者的源码来生成所需的 API
文件。
下面来简要介绍下 protoc
工具在 Linux
系统下的安装流程:
首先在
Protocol Buffer
的Github
软件发布页面下载最新版本的protoc
二进制压缩包文件,或者使用如下命令直接下载。wget https://github.com/protocolbuffers/protobuf/releases/download/v3.12.3/protoc-3.12.3-linux-x86_64.zip
然后将该压缩包解压到
/usr/local/protoc
目录下。unzip protoc-3.12.3-linux-x86_64.zip -d /usr/local/protoc
接着将
/usr/local/protoc/bin
目录加入到PATH
环境变量中。可以将下面这行语句加入到/etc/profile
文件中来达成上述目标。export PATH=$PATH:/usr/local/protoc/bin
最后测试安装成功
protoc --version
API 文件生成及使用
一般而言,只要是 Protocol Buffer
和 GRPC
支持的语言,都可以生成 TensorFlow Serving
的 API
文件。在上一篇文章TensorFlow 2.x 模型 Serving 服务中
我已经介绍过使用 Python
来进行 GPRC
请求的示例,本文则主要介绍使用 Golang
以及 Java
来生成 TensorFlow Serving
的 API
文件以及进行 GPRC
请求的方法。
为了生成可执行的代码文件,我们首先需要将 TensorFlow
和 TensorFlow Serving
的源码都 clone
到本地。
mkdir tensorflow-serving-api && cd tensorflow-serving-api |
接下来就可以使用源码中的 proto
文件来生成相应语言的 TensorFlow Serving API
文件了。
Golang
在生成 golang
相关的 API
代码文件时,我们需要安装 golang
环境以及一些 protoc
插件来辅助我们进行文件生成操作。以下的相关操作均在 Linux
系统完成。
安装
golang
,流程如下所示。wget https://dl.google.com/go/go1.14.4.linux-amd64.tar.gz
tar zxvf go1.14.4.linux-amd64.tar.gz -C /usr/local
export PATH=$PATH:/usr/local/go/bin
go version安装
protoc-gen-go
插件,用于生成go
文件。go get -u google.golang.org/protobuf/cmd/protoc-gen-go
or
go get -u github.com/golang/protobuf/protoc-gen-goprotoc-gen-go
默认会安装在$GOPATH/bin
目录下,需要确保该目录在PATH
下以使得protoc
工具能够找到该插件。安装
grpc
插件,用于生成grpc go
文件。go get -u google.golang.org/grpc
将源码切换到指定分支或标签。
cd tensorflow-serving-api/tensorflow
git checkout tags/v2.2.0
cd tensorflow-serving-api/serving
git checkout tags/2.2.0使用
protoc
工具生成go
文件。cd tensorflow-serving-api
protoc -I=serving -I=tensorflow --go_out=plugins=grpc:golang serving/tensorflow_serving/*/*.proto
protoc -I=serving -I=tensorflow --go_out=plugins=grpc:golang serving/tensorflow_serving/sources/storage_path/*.proto
protoc -I=serving -I=tensorflow --go_out=plugins=grpc:golang tensorflow/tensorflow/core/framework/*.proto
protoc -I=serving -I=tensorflow --go_out=plugins=grpc:golang tensorflow/tensorflow/core/example/*.proto
protoc -I=serving -I=tensorflow --go_out=plugins=grpc:golang tensorflow/tensorflow/core/protobuf/*.proto
protoc -I=serving -I=tensorflow --go_out=plugins=grpc:golang tensorflow/tensorflow/stream_executor/*.proto其中
-I
指定了proto
文件搜索依赖文件的路径,可以指定多次。--go_out
指定了保存go
文件的目录(这里为golang
)以及使用的grpc
插件。命令的最后一项以通配符的形式指定了proto
文件的输入位置。至于为何选择上述的
proto
文件进行API
生成,是根据实际使用情况以及proto
文件之间的依赖关系决定的。可以先从serving
的proto
源码入手,并参照其Python GRPC
示例的代码实现,找到入口的proto
文件,然后根据其本身及依赖的proto
文件生成相应的API
代码文件,接着进行编码测试,查缺补漏,直到所有代码文件编译无误为止。执行完上述命令后会在
golang
目录下生成两个目录,分别为github.com
和tensorflow_serving
,前者包含有从tensorflow
源码中的proto
文件生成的go
文件,后者包含从serving
源码中的proto
文件生成的go
文件。因为tensorflow
源码中的proto
文件均包含go_package
选项如option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_go_proto";
,它们指定了生成的go
文件的输出目录,所以其生成的go
文件会在github.com
目录下,而serving
源码中的proto
文件不包含该选项,所以go
文件的输出目录默认与源码文件的目录相同。tensorflow_serving
目录下生成的go
文件中可能会出现循环引用错误,如下所示:import cycle not allowed
package github.com/alex/tensorflow-serving-api-go
imports tensorflow_serving/apis
imports tensorflow_serving/core
imports tensorflow_serving/apis此时你需要将
tensorflow_serving/core
目录下logging.pb.go
文件和tensorflow_serving/apis
目录下的prediction_log.pb.go
文件删除以解决上述问题。删除上述代码文件并不影响后续的GRPC
模型预测请求。假设我有一个名为
first_model
的模型部署在了TensorFlow Serving
服务上,它的元数据信息如下所示:curl http://localhost:8501/v1/models/first_model/versions/0/metadata
{
"model_spec": {
"name": "first_model",
"signature_name": "",
"version": "0"
},
"metadata": {
"signature_def": {
"signature_def": {
"serving_default": {
"inputs": {
"input_1": {
"dtype": "DT_INT64",
"tensor_shape": {
"dim": [
{
"size": "-1",
"name": ""
},
{
"size": "31",
"name": ""
}
],
"unknown_rank": false
},
"name": "serving_default_input_1:0"
}
},
"outputs": {
"output_1": {
"dtype": "DT_FLOAT",
"tensor_shape": {
"dim": [
{
"size": "-1",
"name": ""
},
{
"size": "1",
"name": ""
}
],
"unknown_rank": false
},
"name": "StatefulPartitionedCall:0"
}
},
"method_name": "tensorflow/serving/predict"
},
"__saved_model_init_op": {
"inputs": {},
"outputs": {
"__saved_model_init_op": {
"dtype": "DT_INVALID",
"tensor_shape": {
"dim": [],
"unknown_rank": true
},
"name": "NoOp"
}
},
"method_name": ""
}
}
}
}
}我们需要重点关注上述信息中的
inputs
选项,它定义了该模型输入数据的key
值(这里为input_1
) 、输入数据的维度(这里为(-1, 31)
)以及输入数据的类型 (这里为DT_INT64
)。在进行GRPC
预测请求时,代码中指定的输入数据需要与元数据中定义的各种输入信息相匹配,否则就无法获取正确的模型输出。创建
go
项目并向first_model
发送GRPC
预测请求。具体的项目详情请参见Github
上的实现及说明,这里只列出主函数的代码,如下所示:package main
import (
"context"
"log"
apis "tensorflow_serving/apis"
"time"
"github.com/golang/protobuf/ptypes/wrappers"
"github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_go_proto"
"github.com/tensorflow/tensorflow/tensorflow/go/core/framework/tensor_shape_go_proto"
"github.com/tensorflow/tensorflow/tensorflow/go/core/framework/types_go_proto"
"google.golang.org/grpc"
)
var (
// TensorFlow serving grpc address.
address = "127.0.0.1:8500"
)
func main() {
// Create a grpc request.
request := &apis.PredictRequest{
ModelSpec: &apis.ModelSpec{},
Inputs: make(map[string]*tensor_go_proto.TensorProto),
}
request.ModelSpec.Name = "first_model"
request.ModelSpec.SignatureName = "serving_default"
// request.ModelSpec.VersionChoice = &apis.ModelSpec_VersionLabel{VersionLabel: "stable"}
request.ModelSpec.VersionChoice = &apis.ModelSpec_Version{Version: &wrappers.Int64Value{Value: 0}}
request.Inputs["input_1"] = &tensor_go_proto.TensorProto{
Dtype: types_go_proto.DataType_DT_INT64,
TensorShape: &tensor_shape_go_proto.TensorShapeProto{
Dim: []*tensor_shape_go_proto.TensorShapeProto_Dim{
&tensor_shape_go_proto.TensorShapeProto_Dim{
Size: int64(2),
},
&tensor_shape_go_proto.TensorShapeProto_Dim{
Size: int64(31),
},
},
},
Int64Val: []int64{
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
},
}
// Create a grpc connection.
conn, err := grpc.Dial(address, grpc.WithInsecure(), grpc.WithBlock(), grpc.WithTimeout(10*time.Second))
if err != nil {
log.Fatalf("couldn't connect: %s", err.Error())
}
defer conn.Close()
// Wrap the grpc uri with client.
client := apis.NewPredictionServiceClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// Send the grpc request.
response, err := client.Predict(ctx, request)
if err != nil {
log.Fatalf("couldn't get response: %v", err)
}
log.Printf("%+v", response)
}Github
地址为:https://github.com/AlexanderJLiu/tensorflow-serving-api/tree/master/golang
Java
在生成 java
相关的 API
代码文件时,我们需要安装 java
环境以及一些 protoc
插件来辅助我们进行文件生成操作。以下相关操作均在 Linux
系统完成。
安装
OpenJDK
。centos
yum-config-manager --enable rhel-7-server-optional-rpms
yum install java-11-openjdk-devel
ubuntu
apt-get install openjdk-11-jdk
test
java -version安装
protoc-gen-grpc-java
,用于生成grpc java
文件。wget https://repo1.maven.org/maven2/io/grpc/protoc-gen-grpc-java/1.30.2/protoc-gen-grpc-java-1.30.2-linux-x86_64.exe
mv protoc-gen-grpc-java-1.30.2-linux-x86_64.exe /usr/local/protoc/bin/protoc-gen-grpc-java将源码切换到指定分支或标签。
cd tensorflow-serving-api/tensorflow
git checkout tags/v2.2.0
cd tensorflow-serving-api/serving
git checkout tags/2.2.0使用
protoc
工具生成java
文件。cd tensorflow-serving-api
protoc -I=serving -I=tensorflow --plugin=/usr/local/protoc/bin/protoc-gen-grpc-java --grpc-java_out=java --java_out=java serving/tensorflow_serving/*/*.proto
protoc -I=serving -I=tensorflow --plugin=/usr/local/protoc/bin/protoc-gen-grpc-java --grpc-java_out=java --java_out=java serving/tensorflow_serving/sources/storage_path/*.proto其中
-I
指定了proto
文件搜索依赖文件的路径,可以指定多次。--plugin
指定了要使用的grpc
插件的路径。--grpc-java_out
指定了grpc java
文件的保存目录。--java_out
指定了java
文件的保存目录。命令的最后一项以通配符的形式指定了proto
文件的输入位置。由于
TensorFlow
官方已经基于proto
文件生成了TensorFlow
的java
文件,因此我们就没有必要自己生成了,在使用时直接从maven
仓库引入即可:implementation("org.tensorflow:proto:1.15.0")
。创建
java
项目并向first_model
发送GRPC
预测请求。具体的项目详情请参见Github
上的实现及说明,这里只列出主函数的代码,如下所示:package com.github.alex;
import java.util.Arrays;
import java.util.concurrent.TimeUnit;
import com.google.protobuf.Int64Value;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import org.tensorflow.framework.TensorShapeProto.Dim;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.stub.StreamObserver;
import tensorflow.serving.Model.ModelSpec;
import tensorflow.serving.Predict.PredictRequest;
import tensorflow.serving.Predict.PredictResponse;
import tensorflow.serving.PredictionServiceGrpc;
import tensorflow.serving.PredictionServiceGrpc.PredictionServiceBlockingStub;
import tensorflow.serving.PredictionServiceGrpc.PredictionServiceStub;
public class App {
public String getGreeting() {
return "Hello world.";
}
public static void main(String[] args) {
PredictRequest.Builder requestBuilder = PredictRequest.newBuilder();
ModelSpec.Builder modelSpecBuilder = ModelSpec.newBuilder();
modelSpecBuilder.setSignatureName("serving_default");
modelSpecBuilder.setName("first_model");
modelSpecBuilder.setVersion(Int64Value.newBuilder().setValue(0L));
requestBuilder.setModelSpec(modelSpecBuilder.build());
TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
tensorProtoBuilder.setDtype(DataType.DT_INT64);
Dim[] dim = {Dim.newBuilder().setSize(1).build(), Dim.newBuilder().setSize(31).build()};
tensorProtoBuilder
.setTensorShape(TensorShapeProto.newBuilder().addAllDim(Arrays.asList(dim)));
// tensorProtoBuilder.setTensorShape(TensorShapeProto.newBuilder()
// .addDim(Dim.newBuilder().setSize(1)).addDim(Dim.newBuilder().setSize(31)));
Long[] inputs = {1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L,
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L};
tensorProtoBuilder.addAllInt64Val(Arrays.asList(inputs));
requestBuilder.putInputs("input_1", tensorProtoBuilder.build());
PredictRequest request = requestBuilder.build();
System.out.println(request);
String target = "127.0.0.1:8500";
// Create a communication channel to the server, known as a Channel. Channels are
// thread-safe and reusable. It is common to create channels at the beginning of your
// application and reuse them until the application shuts down.
ManagedChannel channel = ManagedChannelBuilder.forTarget(target)
// Channels are secure by default (via SSL/TLS). For the example we disable TLS to
// avoid needing certificates.
.usePlaintext().build();
PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);
try {
PredictResponse response = stub.predict(request);
System.out.println(response);
channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}Github
地址为:https://github.com/AlexanderJLiu/tensorflow-serving-api/tree/master/java
其他语言
我在 Github
上创建了一个名为 tensorflow-serving-api
项目,该项目旨在生成 Protocol Buffer
和 GRPC
所支持的所有语言的 TensorFlow Serving API
文件,并且以完整项目的形式给出使用的示例。
该项目正在逐渐完善中,目前已经实现了 Golang
, Java
以及 Python
语言的 TensorFlow Serving API
和项目示例,后续还会加入更多的语言实现,也欢迎大家来共同参与贡献。
项目链接地址:https://github.com/AlexanderJLiu/tensorflow-serving-api
参考资料
- Go Installation Instructions
- Protocol Buffer Basics: Go
- ProtoBuf: Go Generated Code
- Go GRPC Examples
- Install OpenJDK on Windows and Linux
- Protocol Buffer Basics: Java
- ProtoBuf: Java Generated Code
- GRPC: Java Generated-code Reference