Post

(TensorFlow) Datasets API

자세한 설명은 가이드 참조.

Dataset API 를 사용하면 input pipelines / Threading and Queue 과정을 손쉽게 처리할 수 있다. ( 1.4부터 contrib에서 코어로 옮겨졌다. )

datasetelement들로 이루어져 있으며 elementtf.Tensor들로 이루어져 있다.

  • element1
    • image1
    • label1
  • element2
    • image2
    • label2

Basic mechanics

1
2
3
4
5
6
7
8
9
>>> sess.run(tf.random\_uniform([2, 4]))
array([[ 0.77109301,  0.34201586,  0.0554806 ,  0.96262276],
[ 0.99343991,  0.84189892,  0.8897506 ,  0.27429628]], dtype=float32)
>>> dataset1 = tf.data.Dataset.from\_tensor\_slices(tf.random\_uniform([2, 4]))
>>> dataset1.output\_types
tf.float32
>>> dataset1.output\_shapes
TensorShape([Dimension(4)])

한 행이 한 element라고 생각하면 된다. 한 row 씩 반환하게 되고, 한 row는 한 element이기 때문에 output\_shapeselement의 shape이다.

dataset에서 element를 꺼내기 위해서는 tf.data.Iterator를 사용한다.

1
2
3
4
5
6
7
8
9
>>> dataset = tf.data.Dataset.range(10)
>>> iterator = dataset.make\_one\_shot\_iterator()
>>> next\_element = iterator.get\_next()
>>> sess = tf.Session()
>>> sess.run(next\_element)
0
>>> sess.run(next\_element)
1

Note: Currently, one-shot iterators are the only type that is easily usable with an Estimator. (17.11.17)

그러나 다음과 같은 parameterization은 지원하지 않는다.

1
2
3
4
5
>>> max\_value = tf.placeholder(tf.int64, shape=[])
>>> dataset = tf.data.Dataset.range(max\_value)
>>> iterator = dataset.make\_one\_shot\_iterator()
ValueError: Cannot capture a placeholder (name:Placeholder\_1, type:Placeholder) by value.

그래서 이와같이 사용하려면 dataset.make\_initializable\_iterator() 등등을 사용해야 한다. 가장 유연하고 좋은 방식은 feedable iterator를 사용하는 것인데, 위에 적어놓았듯 Estimator와 같이 쓰기가 좀 그렇다.

그리고 주의해야 할 점은, Iterator.get\_next()가 호출될 때 마다 Iterator에서 새로운 element를 꺼내는 것이 아니라는 점이다. sess.run()에 집어넣어야 다음 element를 반환한다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
>>> dataset = tf.data.Dataset.range(5)
>>> iterator = dataset.make\_one\_shot\_iterator()
>>> next\_element = iterator.get\_next()
>>> # 여기서 next\_element가 두 번 호출된다고 두 번 꺼내는게 아니다.
>>> result = tf.add(next\_element, next\_element)
>>> sess.run(result)
0
>>> sess.run(result)
2    # 1+1 = 2
>>> sess.run(next\_element)
2    # 2
>>> sess.run(result)
6    # 3+3 = 2

그래서 다음과 같이 사용한다.

1
2
3
4
5
6
7
sess.run(iterator.initializer)
while True:
try:
sess.run(result)
except tf.errors.OutOfRangeError:
break

Decoding image data and resizing it

이런 식으로 사용한다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
## Reads an image from a file, decodes it into a dense tensor, and resizes it
## to a fixed shape.
def \_image\_processing(filename, label):
image\_string = tf.read\_file(filename)
image\_decoded = tf.image.decode\_png(image\_string)
## tf.image.decode\_image를 사용하면 더 좋지만 에러 발생
image\_resized = tf.image.resize\_images(image\_decoded, [28, 28])
return image\_resized, label
## A vector of filenames.

  
fnames = tf.constant(glob.glob("./mnist\_test/\*"))
labels = tf.constant([2])
dataset = tf.data.Dataset.from\_tensor\_slices((fnames, labels))
dataset = dataset.map(\_image\_processing)

Applying arbitrary Python logic with tf.py_func()

텐서플로우의 동작 방식이 일반적인 python logic과는 달리 그래프를 구성하고, 나중에 실행하는 방식이다 보니 원래대로라면 OpenCV같은 다른 API의 파일 처리와 연계하기가 조금 복잡스럽다. 그러나 이를 간단히 처리할 수 있도록 tf.py\_func()라는 API를 지원한다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import cv2

  

## Use a custom OpenCV function to read the image, instead of the standard
## TensorFlow `tf.read\_file()` operation.
def \_read\_py\_function(filename, label):
image\_decoded = cv2.imread(image\_string, cv2.IMREAD\_GRAYSCALE)
return image\_decoded, label

  

## Use standard TensorFlow operations to resize the image to a fixed shape.
def \_resize\_function(image\_decoded, label):
image\_decoded.set\_shape([None, None, None])
image\_resized = tf.image.resize\_images(image\_decoded, [28, 28])
return image\_resized, label

  

filenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...]
labels = [0, 37, 29, 1, ...]

  

dataset = tf.data.Dataset.from\_tensor\_slices((filenames, labels))
dataset = dataset.map(
lambda filename, label: tuple(tf.py\_func(
\_read\_py\_function, [filename, label], [tf.uint8, label.dtype])))
dataset = dataset.map(\_resize\_function)

shuffle, epoch, batch

1
2
3
4
5
6
>>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.map(...)
>>> dataset = dataset.shuffle(buffer\_size = 10)
>>> dataset = dataset.batch(2)     # batch
>>> dataset = dataset.repeat(3)    # epoch. 지정하지 않으면 무한히 제공.

This post is licensed under CC BY 4.0 by the author.