Skip to content
This repository was archived by the owner on Jul 10, 2025. It is now read-only.

Commit 3c39017

Browse files
authored
Merge pull request #1 from 1025KB/1025KB-patch-2-1
1025 kb patch 2 1
2 parents 57e8fcb + 44159ab commit 3c39017

File tree

5 files changed

+639
-0
lines changed

5 files changed

+639
-0
lines changed

rfcs/20200117-tfx-generic-trainer.md

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
# TFX Generic Trainer
2+
3+
| Status | Proposed |
4+
| :------------ | :-------------------------------------------------------- |
5+
| **Author(s)** | Jiayi Zhao ([email protected]) |
6+
| **Sponsor** | Konstantinos Katsiapis ([email protected]), Zhitao Li ([email protected]), Karmel Allison ([email protected]) |
7+
| **Updated** | 2020-01-17 |
8+
9+
## Objective
10+
11+
### Goal
12+
13+
* Support any TensorFlow Training loop in TFX Trainer in addition to
14+
tf.estimator, primarily focused on native Keras model.
15+
16+
### Non Goal
17+
18+
* Natively support multi-worker distributed training by the system.
19+
* Non-TF training that generates savedmodel.
20+
21+
## Background and Motivation
22+
23+
In current TFX Trainer component, only tf.estimator is supported for training
24+
and generating models. User provides a module file which contains a
25+
`trainer_fn`, trainer will call the function to get the estimator model and
26+
related spec for training, and generate a saved model by
27+
`tf.estimator.train_and_evaluate`.
28+
29+
[tf.keras](https://www.tensorflow.org/guide/keras) is TensorFlow's high-level
30+
API for building and training models. It’s currently supported in TFX by using
31+
`tf.keras.estimator.model_to_estimator` in module file. User can create keras
32+
model in their `trainer_fn` but need to convert it to estimator for return (for
33+
example,
34+
[cifar10](https://github.com/tensorflow/tfx/blob/r0.15/tfx/examples/cifar10/cifar10_utils.py)).
35+
36+
This doc will focus on native Keras support (without model_to_estimator) in TFX.
37+
We propose changing the user facing API to be more generic so that users can do
38+
(single node) native Keras model training within TFX.
39+
40+
## User Benefit
41+
42+
* Allows non estimator based training, especially Keras as TensorFlow is
43+
establishing Keras as the
44+
[Standardized high-level API](https://medium.com/tensorflow/standardizing-on-keras-guidance-on-high-level-apis-in-tensorflow-2-0-bad2b04c819a).
45+
* Allows
46+
[custom training](https://www.tensorflow.org/tutorials/customization/custom_training)
47+
for customization of training loop.
48+
49+
## Detailed Design
50+
51+
Below shows the pseudo code for current TFX Trainer’s executor:
52+
53+
```python
54+
class Executor(base_executor.BaseExecutor):
55+
56+
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
57+
output_dict: Dict[Text, List[types.Artifact]],
58+
exec_properties: Dict[Text, Any]) -> None:
59+
"""Uses a user-supplied tf.estimator to train a tf model locally."""
60+
trainer_fn = self._GetFn(exec_properties) # load from module file
61+
trainer_fn_args = self._GetFnArgs(
62+
input_dict, output_dict, exec_properties)
63+
64+
training_spec = trainer_fn(trainer_fn_args)
65+
tf.estimator.train_and_evaluate(training_spec['estimator'], ...)
66+
# For TFMA (downstream evaluator and model validator component).
67+
tfma.export.export_eval_savedmodel(training_spec['estimator'], ...)
68+
```
69+
70+
And the user supplied module file contains a function called `trainer_fn` which
71+
returns an estimator:
72+
73+
```python
74+
def _build_keras_model() -> tf.keras.Model:
75+
model = keras.XXX
76+
model.compile(...)
77+
return model
78+
79+
def trainer_fn(
80+
trainer_fn_args: trainer.executor.TrainerFnArgs) -> Dict[Text, Any]:
81+
"""Build the estimator using the high level API.
82+
83+
Args:
84+
trainer_fn_args: Holds args used to train the model as name/value pairs.
85+
86+
Returns:
87+
A dict of the following:
88+
- estimator: The estimator that will be used for training and eval.
89+
- train_spec: Spec for training.
90+
- eval_spec: Spec for eval.
91+
- eval_input_receiver_fn: Input function for eval.
92+
"""
93+
...
94+
95+
estimator = tf.keras.estimator.model_to_estimator(
96+
keras_model=_build_keras_model(), ...)
97+
98+
return {
99+
'estimator': estimator,
100+
'train_spec': ...,
101+
'eval_spec': ...,
102+
'eval_input_receiver_fn': ...
103+
}
104+
105+
```
106+
107+
We propose that in generic trainer's module file, user not only need to provide
108+
the model, but also control how the model is trained (`train_and_evaluate` for
109+
estimator and `model.fit` for keras will be in user module file instead of in
110+
executor), thus executor can be generic to model, and users can customize the
111+
[training loop](https://www.tensorflow.org/tutorials/customization/custom_training_walkthrough#training_loop).
112+
The executor pseudo code would look like below:
113+
114+
```python
115+
class Executor(base_executor.BaseExecutor):
116+
117+
def Do(self, input_dict: Dict[Text, List[types.Artifact]],
118+
output_dict: Dict[Text, List[types.Artifact]],
119+
exec_properties: Dict[Text, Any]) -> None:
120+
"""Train a user-supplied tf model."""
121+
run_fn = self._GetRunFn(exec_properties) # load from module file
122+
123+
# run_fn_args contains
124+
# 1. input train and eval data path.
125+
# 2. desired output model path for the trained savedmodel.
126+
# 3. training args, e.g., train/eval steps.
127+
# 4. optional base model.
128+
# 5. optional tuning result (kerastuner.HyperParameters config).
129+
# 6. optional custom config for passing params from component.
130+
run_fn_args = self._GetRunFnArgs(
131+
input_dict, output_dict, exec_properties)
132+
133+
run_fn(run_fn_args)
134+
# Validates the existence of run_fn's output savedmodel.
135+
...
136+
```
137+
138+
In module file, user needs to provide `run_fn` instead of previous `trainer_fn`.
139+
The `trainer_fn` was responsible for creating the model, in addition to that,
140+
`run_fn` also needs to handle training part and output the trained model to a
141+
desired location given by run args:
142+
143+
```python
144+
def run_fn(args: trainer.executor.TrainerFnArgs) -> None:
145+
"""Build the TF model and train it."""
146+
model = _build_keras_model()
147+
model.fit(...)
148+
# Save model to args.serving_model_dir.
149+
model.save(...)
150+
```
151+
152+
In generic trainer, executor is mainly for handling the
153+
[artifact](https://github.com/tensorflow/tfx/blob/r0.21/docs/guide/index.md#artifacts)
154+
(a unit of data that is passed between components), all model related logic is
155+
user supplied.
156+
157+
A separate GenericExecutor will be created, and the existing trainer executor
158+
will be sunsetted. We plan to keep estimator based executor for one more version
159+
and then deprecate it.
160+
161+
### How to convert current estimator based module file
162+
163+
To convert the current estimator based module file (e.g.,
164+
[iris](https://github.com/tensorflow/tfx/blob/r0.15/tfx/examples/iris/iris_utils.py))
165+
for generic trainer, simply add a run_fn that calls the trainer_fn and train the
166+
returned model (code that used to be in the trainer.executor.Do).
167+
168+
```python
169+
def run_fn(fn_args: executor.TrainerFnArgs):
170+
"""Train the model based on given args.
171+
172+
Args:
173+
fn_args: Holds args used to train the model as name/value pairs.
174+
"""
175+
schema = io_utils.parse_pbtxt_file(fn_args.schema_file, schema_pb2.Schema())
176+
177+
# Reuse the trainer_fn.
178+
training_spec = trainer_fn(fn_args, schema)
179+
180+
# Train the model
181+
absl.logging.info('Training model.')
182+
tf.estimator.train_and_evaluate(training_spec['estimator'],
183+
training_spec['train_spec'],
184+
training_spec['eval_spec'])
185+
absl.logging.info('Training complete. Model written to %s',
186+
fn_args.serving_model_dir)
187+
188+
# Export an eval savedmodel for TFMA, note that for keras, eval savedmodel is
189+
# not needed as TFMA2 can use serving model for evaluation.
190+
absl.logging.info('Exporting eval_savedmodel for TFMA.')
191+
tfma.export.export_eval_savedmodel(
192+
estimator=training_spec['estimator'],
193+
export_dir_base=fn_args.eval_model_dir,
194+
eval_input_receiver_fn=training_spec['eval_input_receiver_fn'])
195+
196+
absl.logging.info('Exported eval_savedmodel to %s.', fn_args.eval_model_dir)
197+
```
198+
199+
### tf.distribute.Strategy
200+
201+
Distribution strategy will be user module's responsibilty with the new generic
202+
trainer interface. To use it, user needs to modify the `run_fn()` in the module
203+
file, below shows the pseudo code example for single worker and multi-worker
204+
distribute strategy.
205+
206+
For single worker distribute strategy, you need to create an appropriate
207+
[tf.distribute.Strategy](https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy),
208+
and move the creation and compiling of Keras model inside `strategy.scope`:
209+
210+
```python
211+
def run_fn(args: trainer.executor.TrainerFnArgs) -> None:
212+
"""Build the TF model and train it."""
213+
mirrored_strategy = tf.distribute.MirroredStrategy()
214+
with mirrored_strategy.scope():
215+
model = _build_keras_model()
216+
model.fit(...)
217+
model.save(...)
218+
```
219+
220+
For multi-worker distribution strategy, the TFX Trainer does not have ability to
221+
spawn multi-worker cluster by
222+
[current executor](https://github.com/tensorflow/tfx/blob/r0.21/tfx/components/trainer/executor.py),
223+
hence not covered in the scope of this RFC. If the execution environment of an
224+
implementation of TFX Trainer has the ability to bring up the cluster of worker
225+
machines, and execute user funtion in the workers with correct
226+
[TF_CONFIG setup](https://www.tensorflow.org/guide/distributed_training#setting_up_tf_config_environment_variable),
227+
such as GCP AI Platform Training service via
228+
[extensions/google_cloud_ai_platform/trainer/executor.py](https://github.com/tensorflow/tfx/blob/r0.21/tfx/extensions/google_cloud_ai_platform/trainer/executor.py),
229+
the `run_fn()` would look like below:
230+
231+
```python
232+
def _is_chief() -> bool:
233+
"""Decide whether the current worker's role is chief."""
234+
# Check TF_CONFIG (set by TFX when bring up the worker) in execution env.
235+
...
236+
237+
def run_fn(args: trainer.executor.TrainerFnArgs) -> None:
238+
"""Build the TF model and train it."""
239+
ps_strategy = tf.distribute.experimental.ParameterServerStrategy()
240+
with ps_strategy.scope():
241+
model = _build_keras_model()
242+
model.fit(...)
243+
if _is_chief():
244+
model.save(...)
245+
```
246+
247+
For details about `tf.distribute.Strategy`, please refer to
248+
[here](https://www.tensorflow.org/guide/distributed_training).
249+
250+
## Future work
251+
252+
* Examples for custom training loop.
253+
* Native support for multi-worker distribution.

0 commit comments

Comments
 (0)