|
1 |
| -import tensorflow as tf |
2 |
| -from keras_cv_attention_models.imagenet.tf_data import init_mean_std_by_rescale_mode, tf_imread, random_crop_and_resize_image, build_custom_dataset |
3 |
| - |
4 |
| - |
5 |
def image_process(image, image_size=(224, 224), is_train=True):
    """Load one image and produce a fixed-shape float32 tensor.

    The image is read via `tf_imread`, then either randomly cropped and
    resized (training) or directly resized (evaluation), both with bicubic
    interpolation and antialiasing.

    Args:
        image: input accepted by `tf_imread` (presumably an image file path —
            confirm against the project's tf_data helpers).
        image_size: target (height, width).
        is_train: True → random crop + resize; False → plain resize.

    Returns:
        float32 tensor with static shape (*image_size, 3).
    """
    decoded = tf_imread(image)
    if is_train:
        # random_crop_and_resize_image returns a tuple; element 0 is the image.
        resized = random_crop_and_resize_image(
            decoded, image_size, scale=(0.9, 1.0), method="bicubic", antialias=True
        )[0]
    else:
        resized = tf.image.resize(decoded, image_size, method="bicubic", antialias=True)
    resized = tf.cast(resized, tf.float32)
    # Pin the static shape so downstream graph ops see a known (H, W, 3).
    resized.set_shape([*image_size, 3])
    return resized
14 |
| - |
15 |
| - |
16 |
def init_dataset(data_path, caption_tokenizer, batch_size=64, image_size=224, rescale_mode="torch"):
    """Build (train, test) tf.data pipelines for image–caption contrastive training.

    Args:
        data_path: dataset location forwarded to `build_custom_dataset`.
        caption_tokenizer: tokenizer applied to captions by the dataset builder.
        batch_size: per-step batch size; both splits drop the remainder batch.
        image_size: int or (height, width) target image size.
        rescale_mode: normalization preset name for `init_mean_std_by_rescale_mode`.

    Returns:
        (train_dataset, test_dataset). `test_dataset` is None when the source
        has neither a "validation" nor a "test" split. Each element is
        ((normalized_images, captions), y_true) where y_true = range(batch_size):
        sample i's positive caption is caption i (CLIP-style diagonal labels).
    """
    # total_images / num_classes / num_channels are unpacked but unused here;
    # kept so the with_info=True return shape stays explicit.
    dataset, total_images, num_classes, num_channels = build_custom_dataset(
        data_path, with_info=True, caption_tokenizer=caption_tokenizer
    )

    mean, std = init_mean_std_by_rescale_mode(rescale_mode)
    image_size = image_size if isinstance(image_size, (list, tuple)) else [image_size, image_size]

    AUTOTUNE, buffer_size, seed = tf.data.AUTOTUNE, batch_size * 100, None
    train_pre_batch = lambda data_point: (image_process(data_point["image"], image_size, is_train=True), data_point["caption"])
    # Diagonal contrastive labels — valid only because drop_remainder=True
    # guarantees every batch contains exactly batch_size samples.
    y_true = tf.range(batch_size)
    post_batch = lambda xx, caption: (((xx - mean) / std, caption), y_true)

    train_dataset = dataset["train"]
    train_dataset = train_dataset.shuffle(buffer_size, seed=seed).map(train_pre_batch, num_parallel_calls=AUTOTUNE)
    train_dataset = train_dataset.batch(batch_size, drop_remainder=True).map(post_batch, num_parallel_calls=AUTOTUNE)
    train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)

    test_dataset = dataset.get("validation", dataset.get("test", None))
    if test_dataset is not None:
        test_pre_batch = lambda data_point: (image_process(data_point["image"], image_size, is_train=False), data_point["caption"])
        test_dataset = test_dataset.map(test_pre_batch, num_parallel_calls=AUTOTUNE)
        # Fix: the eval post-batch map previously ran serially (no
        # num_parallel_calls) and the eval pipeline had no prefetch.
        test_dataset = test_dataset.batch(batch_size, drop_remainder=True).map(post_batch, num_parallel_calls=AUTOTUNE)
        test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)

    return train_dataset, test_dataset
| 1 | +Unit test code for this: |
| 2 | +import tensorflow as tf |
| 3 | +from keras_cv_attention_models.imagenet.tf_data import init_mean_std_by_rescale_mode, tf_imread, random_crop_and_resize_image, build_custom_dataset |
| 4 | + |
| 5 | + |
def image_process(image, image_size=(224, 224), is_train=True):
    """Load one image and produce a fixed-shape float32 tensor.

    The image is read via `tf_imread`, then either randomly cropped and
    resized (training) or directly resized (evaluation), both with bicubic
    interpolation and antialiasing.

    Args:
        image: input accepted by `tf_imread` (presumably an image file path —
            confirm against the project's tf_data helpers).
        image_size: target (height, width).
        is_train: True → random crop + resize; False → plain resize.

    Returns:
        float32 tensor with static shape (*image_size, 3).
    """
    decoded = tf_imread(image)
    if is_train:
        # random_crop_and_resize_image returns a tuple; element 0 is the image.
        resized = random_crop_and_resize_image(
            decoded, image_size, scale=(0.9, 1.0), method="bicubic", antialias=True
        )[0]
    else:
        resized = tf.image.resize(decoded, image_size, method="bicubic", antialias=True)
    resized = tf.cast(resized, tf.float32)
    # Pin the static shape so downstream graph ops see a known (H, W, 3).
    resized.set_shape([*image_size, 3])
    return resized
| 15 | + |
| 16 | + |
def init_dataset(data_path, caption_tokenizer, batch_size=64, image_size=224, rescale_mode="torch"):
    """Build (train, test) tf.data pipelines for image–caption contrastive training.

    Args:
        data_path: dataset location forwarded to `build_custom_dataset`.
        caption_tokenizer: tokenizer applied to captions by the dataset builder.
        batch_size: per-step batch size; both splits drop the remainder batch.
        image_size: int or (height, width) target image size.
        rescale_mode: normalization preset name for `init_mean_std_by_rescale_mode`.

    Returns:
        (train_dataset, test_dataset). `test_dataset` is None when the source
        has neither a "validation" nor a "test" split. Each element is
        ((normalized_images, captions), y_true) where y_true = range(batch_size):
        sample i's positive caption is caption i (CLIP-style diagonal labels).
    """
    # total_images / num_classes / num_channels are unpacked but unused here;
    # kept so the with_info=True return shape stays explicit.
    dataset, total_images, num_classes, num_channels = build_custom_dataset(
        data_path, with_info=True, caption_tokenizer=caption_tokenizer
    )

    mean, std = init_mean_std_by_rescale_mode(rescale_mode)
    image_size = image_size if isinstance(image_size, (list, tuple)) else [image_size, image_size]

    AUTOTUNE, buffer_size, seed = tf.data.AUTOTUNE, batch_size * 100, None
    train_pre_batch = lambda data_point: (image_process(data_point["image"], image_size, is_train=True), data_point["caption"])
    # Diagonal contrastive labels — valid only because drop_remainder=True
    # guarantees every batch contains exactly batch_size samples.
    y_true = tf.range(batch_size)
    post_batch = lambda xx, caption: (((xx - mean) / std, caption), y_true)

    train_dataset = dataset["train"]
    train_dataset = train_dataset.shuffle(buffer_size, seed=seed).map(train_pre_batch, num_parallel_calls=AUTOTUNE)
    train_dataset = train_dataset.batch(batch_size, drop_remainder=True).map(post_batch, num_parallel_calls=AUTOTUNE)
    train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)

    test_dataset = dataset.get("validation", dataset.get("test", None))
    if test_dataset is not None:
        test_pre_batch = lambda data_point: (image_process(data_point["image"], image_size, is_train=False), data_point["caption"])
        test_dataset = test_dataset.map(test_pre_batch, num_parallel_calls=AUTOTUNE)
        # Fix: the eval post-batch map previously ran serially (no
        # num_parallel_calls) and the eval pipeline had no prefetch.
        test_dataset = test_dataset.batch(batch_size, drop_remainder=True).map(post_batch, num_parallel_calls=AUTOTUNE)
        test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)

    return train_dataset, test_dataset
0 commit comments