// Source: hw_sparkdl.sclpt (123 lines, 101 loc, 3.93 KB) from a fork of databricks/tensorframes.
// -*- scala -*-
// Register the Spark Packages Maven repository so spark-package artifacts can be resolved.
// Spark Packages moved off Bintray when Bintray was shut down (2021); the old
// https://dl.bintray.com/spark-packages/maven/ URL no longer resolves.
import coursier.maven.MavenRepository
interp.repositories() ++= Seq(
  MavenRepository("https://repos.spark-packages.org/")
)
// Load all packages into the session
// REF: http://www.lihaoyi.com/Ammonite/#ImportedScriptsareRe-used
import java.util.UUID
import java.util.concurrent.atomic.AtomicBoolean
import scala.util.Random
// Load Spark settings
import $exec.DevSparkEnv, DevSparkEnv._
import $exec.DevDataSet, DevDataSet._
import spark.implicits._
import org.apache.spark.sql.types._
/** TensorFlow Java */
import $exec.DevTensorFlow, DevTensorFlow._
/**
 * Schema for an image column: mode, height, width, nChannels and the raw data bytes.
 * Every field is declared non-nullable.
 */
val imageStruct = {
  def required(fieldName: String, fieldType: DataType): StructField =
    StructField(fieldName, fieldType, nullable = false)

  StructType(
    Seq(
      required("mode", StringType),
      required("height", IntegerType),
      required("width", IntegerType),
      required("nChannels", IntegerType),
      required("data", BinaryType)
    )
  )
}
/** Image dataframe loaded from the "car_images" source (from Tim's demo). */
val dfImg = ImageDataSrc.load("car_images")
// Collect the distinct channel counts found in the `image` struct column.
dfImg.select("image.nChannels").distinct().collect()
/** Load image classification model */
val fpModelRoot = FPath.home / "local" / "data" / "tf_model"
val fpIv3 = fpModelRoot / "inception-v3" / "main.pb"
val fpIv3Preproc = fpModelRoot / "inception-v3" / "preprocessor.pb"
// The high-level API, enough for building graphs.
// Both GraphFunctions reuse the paths defined above. The graph imported here is
// then the same file that the lower-level protobuf inspection reads via `fpIv3`,
// and a path change only has to be made in one place.
val gfnMain = GraphFunction(
  fpIv3,
  inputNames = Seq("image_input"),
  outputNames = Seq("prediction_vector")
)
val gfnPreproc = GraphFunction(
  fpIv3Preproc,
  inputNames = Seq("image_input"),
  outputNames = Seq("output")
)
// Import both graph functions into a single builder session, each with its own
// scope name ("preproc" and "inception"). Their ops are then addressed as
// "preproc/..." and "inception/...".
val builder = GraphBuilderSession()
builder.importGraphFunction(gfnPreproc, Some("preproc"))
builder.importGraphFunction(gfnMain, Some("inception"))
// Inspect the input/output ops that are fed and fetched when the graph is executed.
// NOTE(review): presumably showOp prints the op's details — confirm against GraphBuilderSession.
builder.showOp("preproc/image_input")
builder.showOp("inception/image_input")
builder.showOp("inception/prediction_vector")
// Lower-level API: parse the serialized GraphDef protobuf and walk its NodeDefs directly.
// This path reportedly runs slower, because every node's full definition is materialized.
val bytes = Files.readAllBytes(fpIv3)
val gdef = GraphDef.parseFrom(bytes)
val nodes = gdef.getNodeList.asScala
// Placeholder nodes are the graph's inputs; every one of them has to be fed.
val placeholders = nodes.filter(node => node.getOp == "Placeholder")
// Look at the final NodeDef (fields: name, op, input, attr).
// NOTE(review): the original note says the last node reflects the output — confirm.
val n = nodes.last
n.getName
// Names of every node that is consumed by some other node.
// A NodeDef input ref is "name" or "name:src_output" for data edges and "^name"
// for control edges. Strip both decorations so the refs compare against plain node
// names; otherwise nodes consumed only via "name:1" or "^name" would be misreported
// below as having no out-bound edge.
val inputNodes = nodes.flatMap { _.getInputList.asScala }.map { ref =>
  ref.stripPrefix("^").replaceAll(":\\d+$", "")
}.toSet
// Names of all nodes in the graph (any node's output can be fetched).
val outputNodes = nodes.map { _.getName }.toSet
// Sink nodes: nodes whose outputs no other node consumes.
println(s"nodes without out-bound edge: ${outputNodes -- inputNodes}")
n.getInputList.asScala
println(s"TF Graph: #nodes ${nodes.size}")
// Now execute some operation.
// Shape of the preprocessor's input placeholder; only its first dimension is used below.
// NOTE(review): TF reports an unknown dimension as -1, and Array.ofDim would then
// throw NegativeArraySizeException — confirm this dimension is static.
val shape = builder.op("preproc/image_input").output(0).shape
// A random Float input tensor. It is named separately from `tnsrIn`: a second
// `val tnsrIn` in the same script block is a duplicate definition (in the REPL it
// would silently shadow this one, leaving it unreachable).
val tnsrInFloat = tf.Tensor.create(
  Array.fill[Float](shape.size(0).toInt)(Random.nextGaussian().toFloat)
)
// A tensor built from a zero-filled byte array of the same length.
// NOTE(review): TF Java's Tensor.create treats a byte[] as a single STRING
// element rather than a UINT8 vector — confirm this is intended.
val tnsrIn = tf.Tensor.create(
  Array.ofDim[Byte](shape.size(0).toInt)
)
// A sample image file; its raw (still encoded) bytes are wrapped in a tensor
// for the preprocessor's input.
val fpImg = FPath.home / "local" / "data" / "images" / "cat.jpg"
val tnsrImg = {
  val imgBytes = Files.readAllBytes(fpImg)
  tf.Tensor.create(imgBytes)
}
// `Runner` has various input types for `feed`
// https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/Session.Runner
// Run the two stages with separate runners. A Runner accumulates every feed and
// fetch it has been given, so reusing one would re-run the preprocessor, leak its
// re-fetched output, and force indexing past that stale fetch with `.get(1)`.
val fetches = {
  // Stage 1: encoded image bytes -> preprocessed image tensor.
  val tnsrInterm = builder.sess.runner()
    .feed("preproc/image_input", tnsrImg)
    .fetch("preproc/output")
    .run().get(0)
  try {
    // Stage 2: preprocessed tensor -> prediction vector (the only fetch, so index 0).
    builder.sess.runner()
      .feed("inception/image_input", tnsrInterm)
      .fetch("inception/prediction_vector")
      .run().get(0)
  } finally {
    // The intermediate tensor is only needed as input to stage 2; release it.
    tnsrInterm.close()
  }
}
val tnsrOut = fetches
println(tnsrOut.numElements)
// Copy the prediction tensor into a Float array (`ss`).
// The buffer is sized by the total element count. `writeTo` copies numElements()
// values and throws BufferOverflowException if the buffer is too small. Sizing by
// the first dimension alone therefore fails when the output has a leading batch
// dimension (e.g. shape [1, classes]).
val buff = java.nio.FloatBuffer.wrap(
  Array.ofDim[Float](tnsrOut.numElements)
)
tnsrOut.writeTo(buff)
val ss = buff.array
// Inception input from image bytes
//val input = g.opBuilder("Placeholder", "input").setAttr("dtype", ).build.output(0)