
Commit af18cbf

wangguowei33 committed
fix lr in text_train_script.py
1 parent 82d9153 commit af18cbf

File tree

8 files changed: +2262 -2259 lines changed
Lines changed: 39 additions & 38 deletions

@@ -1,38 +1,39 @@
-import tensorflow as tf
-from keras_cv_attention_models.imagenet.tf_data import init_mean_std_by_rescale_mode, tf_imread, random_crop_and_resize_image, build_custom_dataset
-
-
-def image_process(image, image_size=(224, 224), is_train=True):
-    image = tf_imread(image)
-    if is_train:
-        image = random_crop_and_resize_image(image, image_size, scale=(0.9, 1.0), method="bicubic", antialias=True)[0]
-    else:
-        image = tf.image.resize(image, image_size, method="bicubic", antialias=True)
-    image = tf.cast(image, tf.float32)
-    image.set_shape([*image_size, 3])
-    return image
-
-
-def init_dataset(data_path, caption_tokenizer, batch_size=64, image_size=224, rescale_mode="torch"):
-    dataset, total_images, num_classes, num_channels = build_custom_dataset(data_path, with_info=True, caption_tokenizer=caption_tokenizer)
-
-    mean, std = init_mean_std_by_rescale_mode(rescale_mode)
-    image_size = image_size if isinstance(image_size, (list, tuple)) else [image_size, image_size]
-
-    AUTOTUNE, buffer_size, seed = tf.data.AUTOTUNE, batch_size * 100, None
-    train_pre_batch = lambda data_point: (image_process(data_point["image"], image_size, is_train=True), data_point["caption"])
-    y_true = tf.range(batch_size)
-    train_post_batch = lambda xx, caption: (((xx - mean) / std, caption), y_true)
-
-    train_dataset = dataset["train"]
-    train_dataset = train_dataset.shuffle(buffer_size, seed=seed).map(train_pre_batch, num_parallel_calls=AUTOTUNE)
-    train_dataset = train_dataset.batch(batch_size, drop_remainder=True).map(train_post_batch, num_parallel_calls=AUTOTUNE)
-    train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
-
-    test_dataset = dataset.get("validation", dataset.get("test", None))
-    if test_dataset is not None:
-        test_pre_batch = lambda data_point: (image_process(data_point["image"], image_size, is_train=False), data_point["caption"])
-        test_dataset = test_dataset.map(test_pre_batch, num_parallel_calls=AUTOTUNE)
-        test_dataset = test_dataset.batch(batch_size, drop_remainder=True).map(train_post_batch)
-
-    return train_dataset, test_dataset
+Unit test code for this:
+import tensorflow as tf
+from keras_cv_attention_models.imagenet.tf_data import init_mean_std_by_rescale_mode, tf_imread, random_crop_and_resize_image, build_custom_dataset
+
+
+def image_process(image, image_size=(224, 224), is_train=True):
+    image = tf_imread(image)
+    if is_train:
+        image = random_crop_and_resize_image(image, image_size, scale=(0.9, 1.0), method="bicubic", antialias=True)[0]
+    else:
+        image = tf.image.resize(image, image_size, method="bicubic", antialias=True)
+    image = tf.cast(image, tf.float32)
+    image.set_shape([*image_size, 3])
+    return image
+
+
+def init_dataset(data_path, caption_tokenizer, batch_size=64, image_size=224, rescale_mode="torch"):
+    dataset, total_images, num_classes, num_channels = build_custom_dataset(data_path, with_info=True, caption_tokenizer=caption_tokenizer)
+
+    mean, std = init_mean_std_by_rescale_mode(rescale_mode)
+    image_size = image_size if isinstance(image_size, (list, tuple)) else [image_size, image_size]
+
+    AUTOTUNE, buffer_size, seed = tf.data.AUTOTUNE, batch_size * 100, None
+    train_pre_batch = lambda data_point: (image_process(data_point["image"], image_size, is_train=True), data_point["caption"])
+    y_true = tf.range(batch_size)
+    train_post_batch = lambda xx, caption: (((xx - mean) / std, caption), y_true)
+
+    train_dataset = dataset["train"]
+    train_dataset = train_dataset.shuffle(buffer_size, seed=seed).map(train_pre_batch, num_parallel_calls=AUTOTUNE)
+    train_dataset = train_dataset.batch(batch_size, drop_remainder=True).map(train_post_batch, num_parallel_calls=AUTOTUNE)
+    train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
+
+    test_dataset = dataset.get("validation", dataset.get("test", None))
+    if test_dataset is not None:
+        test_pre_batch = lambda data_point: (image_process(data_point["image"], image_size, is_train=False), data_point["caption"])
+        test_dataset = test_dataset.map(test_pre_batch, num_parallel_calls=AUTOTUNE)
+        test_dataset = test_dataset.batch(batch_size, drop_remainder=True).map(train_post_batch)
+
+    return train_dataset, test_dataset
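A note on the pipeline in this file: train_post_batch pairs each normalized image batch with its caption batch and emits y_true = tf.range(batch_size), the usual CLIP-style contrastive labeling in which image i of a batch matches caption i. This is also why both batch calls pass drop_remainder=True: y_true is built once for a fixed batch_size, so a short final batch would not line up with it. The snippet below is a minimal, self-contained sketch of how such labels are typically consumed; the loss itself is not part of this commit, and the embeddings here are random stand-ins for the image/text towers.

import tensorflow as tf

batch_size, embed_dim = 4, 8
# Random stand-ins for L2-normalized image and text embeddings.
image_emb = tf.math.l2_normalize(tf.random.normal([batch_size, embed_dim]), axis=-1)
text_emb = tf.math.l2_normalize(tf.random.normal([batch_size, embed_dim]), axis=-1)

# logits[i, j] = similarity of image i with caption j; matching pairs sit on the diagonal.
logits = tf.matmul(image_emb, text_emb, transpose_b=True)

# y_true = [0, 1, 2, 3]: the correct caption index for image i is i.
y_true = tf.range(batch_size)
loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, logits, from_logits=True)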

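For reference, calling the helper might look like the sketch below. The data-path format and the tokenizer are whatever build_custom_dataset expects; "custom_caption_dataset.json" and my_tokenizer are illustrative placeholders, not values taken from this commit.

# Hypothetical usage sketch; the path and my_tokenizer are placeholders.
train_dataset, test_dataset = init_dataset(
    "custom_caption_dataset.json",   # a dataset spec accepted by build_custom_dataset
    caption_tokenizer=my_tokenizer,  # tokenizer applied to captions by build_custom_dataset
    batch_size=64,
    image_size=224,
    rescale_mode="torch",            # torch-style mean/std normalization
)
for (images, captions), y_true in train_dataset.take(1):
    # images: (64, 224, 224, 3); captions: tokenizer-dependent; y_true: (64,)
    print(images.shape, captions.shape, y_true.shape)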
keras_cv_attention_models/fastvit/fastvit.py
Lines changed: 242 additions & 242 deletions (large diff not rendered by default)

keras_cv_attention_models/gpvit/gpvit.py
Lines changed: 267 additions & 267 deletions (large diff not rendered by default)

keras_cv_attention_models/hornet/hornet.py
Lines changed: 227 additions & 227 deletions (large diff not rendered by default)
