# TFX Generic Trainer

| Status        | Proposed                                                   |
| :------------ | :--------------------------------------------------------- |
| **Author(s)** | Jiayi Zhao ([email protected])                              |
| **Sponsor**   | Konstantinos Katsiapis ([email protected]), Zhitao Li ([email protected]), Karmel Allison ([email protected]) |
| **Updated**   | 2020-01-17                                                  |
## Objective

### Goal

* Support any TensorFlow training loop in the TFX Trainer in addition to
  tf.estimator, primarily focused on native Keras models.

### Non Goal

* Native support for multi-worker distributed training by the system.
* Non-TF training that generates a SavedModel.
## Background and Motivation

In the current TFX Trainer component, only tf.estimator is supported for
training and generating models. The user provides a module file which contains
a `trainer_fn`; the trainer calls this function to get the estimator model and
related specs for training, and generates a saved model with
`tf.estimator.train_and_evaluate`.

[tf.keras](https://www.tensorflow.org/guide/keras) is TensorFlow's high-level
API for building and training models. It’s currently supported in TFX through
`tf.keras.estimator.model_to_estimator` in the module file: users can create a
Keras model in their `trainer_fn` but need to convert it to an estimator before
returning it (for example,
[cifar10](https://github.com/tensorflow/tfx/blob/r0.15/tfx/examples/cifar10/cifar10_utils.py)).

This doc focuses on native Keras support (without model_to_estimator) in TFX.
We propose changing the user-facing API to be more generic so that users can do
(single-node) native Keras model training within TFX.
## User Benefit

* Allows non-estimator based training, especially Keras, as TensorFlow is
  establishing Keras as the
  [standardized high-level API](https://medium.com/tensorflow/standardizing-on-keras-guidance-on-high-level-apis-in-tensorflow-2-0-bad2b04c819a).
* Allows
  [custom training](https://www.tensorflow.org/tutorials/customization/custom_training)
  for customization of the training loop.
## Detailed Design

Below is the pseudo code for the current TFX Trainer’s executor:

```python
class Executor(base_executor.BaseExecutor):

  def Do(self, input_dict: Dict[Text, List[types.Artifact]],
         output_dict: Dict[Text, List[types.Artifact]],
         exec_properties: Dict[Text, Any]) -> None:
    """Uses a user-supplied tf.estimator to train a TF model locally."""
    trainer_fn = self._GetFn(exec_properties)  # Loaded from the module file.
    trainer_fn_args = self._GetFnArgs(
        input_dict, output_dict, exec_properties)

    training_spec = trainer_fn(trainer_fn_args)
    tf.estimator.train_and_evaluate(training_spec['estimator'], ...)
    # Export the eval saved model for TFMA (used by the downstream Evaluator
    # and ModelValidator components).
    tfma.export.export_eval_savedmodel(training_spec['estimator'], ...)
```

The user-supplied module file contains a function called `trainer_fn`, which
returns an estimator:

```python
def _build_keras_model() -> tf.keras.Model:
  model = keras.XXX
  model.compile(...)
  return model


def trainer_fn(
    trainer_fn_args: trainer.executor.TrainerFnArgs) -> Dict[Text, Any]:
  """Build the estimator using the high level API.

  Args:
    trainer_fn_args: Holds args used to train the model as name/value pairs.

  Returns:
    A dict of the following:
      - estimator: The estimator that will be used for training and eval.
      - train_spec: Spec for training.
      - eval_spec: Spec for eval.
      - eval_input_receiver_fn: Input function for eval.
  """
  ...

  estimator = tf.keras.estimator.model_to_estimator(
      keras_model=_build_keras_model(), ...)

  return {
      'estimator': estimator,
      'train_spec': ...,
      'eval_spec': ...,
      'eval_input_receiver_fn': ...
  }
```

We propose that in the generic trainer's module file, the user not only
provides the model but also controls how it is trained: `train_and_evaluate`
for estimator and `model.fit` for Keras move into the user module file instead
of the executor. The executor thus becomes generic with respect to the model,
and users can customize the
[training loop](https://www.tensorflow.org/tutorials/customization/custom_training_walkthrough#training_loop).
The executor pseudo code would look like below:

```python
class Executor(base_executor.BaseExecutor):

  def Do(self, input_dict: Dict[Text, List[types.Artifact]],
         output_dict: Dict[Text, List[types.Artifact]],
         exec_properties: Dict[Text, Any]) -> None:
    """Train a user-supplied TF model."""
    run_fn = self._GetRunFn(exec_properties)  # Loaded from the module file.

    # run_fn_args contains:
    #   1. Input train and eval data paths.
    #   2. Desired output model path for the trained SavedModel.
    #   3. Training args, e.g., train/eval steps.
    #   4. Optional base model.
    #   5. Optional tuning result (kerastuner.HyperParameters config).
    #   6. Optional custom config for passing params from the component.
    run_fn_args = self._GetRunFnArgs(
        input_dict, output_dict, exec_properties)

    run_fn(run_fn_args)
    # Validates the existence of run_fn's output SavedModel.
    ...
```
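
To make that contract concrete, the sketch below shows roughly what the
argument container passed to `run_fn` could hold. The class and field names
are illustrative assumptions that mirror the numbered comment above, not the
final API:

```python
from typing import Any, Dict, List, NamedTuple, Optional, Text


class RunFnArgs(NamedTuple):
  """Hypothetical sketch of the args handed to `run_fn` (not the final API)."""
  train_files: List[Text]      # 1. Input train data paths.
  eval_files: List[Text]       # 1. Input eval data paths.
  serving_model_dir: Text      # 2. Where run_fn must write the SavedModel.
  train_steps: int             # 3. Training args.
  eval_steps: int              # 3. Training args.
  base_model: Optional[Text] = None                  # 4. Optional base model.
  hyperparameters: Optional[Dict[Text, Any]] = None  # 5. Optional tuning result.
  custom_config: Optional[Dict[Text, Any]] = None    # 6. Component params.
```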

In the module file, the user needs to provide a `run_fn` instead of the
previous `trainer_fn`. Whereas `trainer_fn` was only responsible for creating
the model, `run_fn` also needs to handle the training itself and output the
trained model to the desired location given by the run args:

```python
def run_fn(args: trainer.executor.TrainerFnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save the model to args.serving_model_dir.
  model.save(...)
```
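
For illustration, a filled-in Keras `run_fn` could look like the sketch below.
It assumes (rather than specifies) that the examples are gzipped TFRecord files
of `tf.Example` protos, that the args object exposes `train_files`,
`eval_files`, `train_steps`, `eval_steps`, and `serving_model_dir` as listed in
the executor comment above, and that the feature spec and toy model are made up
for the example:

```python
import tensorflow as tf

_LABEL_KEY = 'label'  # Hypothetical label feature.

# Hypothetical feature spec; in a real pipeline this would be derived from the
# schema (or from the Transform output).
_FEATURE_SPEC = {
    'feature_a': tf.io.FixedLenFeature([1], tf.float32),
    _LABEL_KEY: tf.io.FixedLenFeature([], tf.int64),
}


def _input_fn(file_pattern, batch_size=64) -> tf.data.Dataset:
  """Builds a dataset of (features, label) batches from gzipped TFRecords."""
  return tf.data.experimental.make_batched_features_dataset(
      file_pattern=file_pattern,
      batch_size=batch_size,
      features=_FEATURE_SPEC,
      reader=lambda files: tf.data.TFRecordDataset(
          files, compression_type='GZIP'),
      label_key=_LABEL_KEY).repeat()


def _build_keras_model() -> tf.keras.Model:
  """A toy model standing in for the real architecture."""
  inputs = tf.keras.layers.Input(shape=(1,), name='feature_a')
  outputs = tf.keras.layers.Dense(1, activation='sigmoid')(inputs)
  model = tf.keras.Model(inputs=inputs, outputs=outputs)
  model.compile(optimizer='adam', loss='binary_crossentropy')
  return model


def run_fn(fn_args) -> None:
  """Trains the model and writes a SavedModel for downstream components."""
  train_dataset = _input_fn(fn_args.train_files)
  eval_dataset = _input_fn(fn_args.eval_files)

  model = _build_keras_model()
  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps)

  # The executor only validates that a SavedModel exists at this location.
  model.save(fn_args.serving_model_dir, save_format='tf')
```

Because the executor only checks for a SavedModel at `serving_model_dir`, any
training loop that ends with `model.save(...)` would satisfy the contract.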

In the generic trainer, the executor is mainly responsible for handling the
[artifacts](https://github.com/tensorflow/tfx/blob/r0.21/docs/guide/index.md#artifacts)
(the units of data passed between components); all model related logic is user
supplied.

A separate GenericExecutor will be created, and the existing trainer executor
will be sunset. We plan to keep the estimator based executor for one more
version and then deprecate it.
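
For example, a pipeline could opt in to the new executor on the Trainer
component roughly as sketched below. The import paths, the `GenericExecutor`
name, and the surrounding pipeline variables (`module_file`, `example_gen`,
`schema_gen`) are assumptions for illustration, not the final API:

```python
from tfx.components import Trainer
from tfx.components.base import executor_spec
from tfx.components.trainer.executor import GenericExecutor
from tfx.proto import trainer_pb2

# `module_file` points at the file that defines `run_fn`; `example_gen` and
# `schema_gen` are upstream components defined elsewhere in the pipeline.
trainer = Trainer(
    module_file=module_file,
    custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))
```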

### How to convert the current estimator based module file

To convert an existing estimator based module file (e.g.,
[iris](https://github.com/tensorflow/tfx/blob/r0.15/tfx/examples/iris/iris_utils.py))
for the generic trainer, simply add a `run_fn` that calls the `trainer_fn` and
trains the returned model (the code that used to live in `trainer.executor.Do`):

```python
def run_fn(fn_args: executor.TrainerFnArgs):
  """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
  """
  schema = io_utils.parse_pbtxt_file(fn_args.schema_file, schema_pb2.Schema())

  # Reuse the trainer_fn.
  training_spec = trainer_fn(fn_args, schema)

  # Train the model.
  absl.logging.info('Training model.')
  tf.estimator.train_and_evaluate(training_spec['estimator'],
                                  training_spec['train_spec'],
                                  training_spec['eval_spec'])
  absl.logging.info('Training complete. Model written to %s',
                    fn_args.serving_model_dir)

  # Export an eval savedmodel for TFMA. Note that for Keras, an eval savedmodel
  # is not needed, as TFMA2 can use the serving model for evaluation.
  absl.logging.info('Exporting eval_savedmodel for TFMA.')
  tfma.export.export_eval_savedmodel(
      estimator=training_spec['estimator'],
      export_dir_base=fn_args.eval_model_dir,
      eval_input_receiver_fn=training_spec['eval_input_receiver_fn'])

  absl.logging.info('Exported eval_savedmodel to %s.', fn_args.eval_model_dir)
```

### tf.distribute.Strategy

With the new generic trainer interface, the distribution strategy becomes the
user module's responsibility. To use one, the user needs to modify the
`run_fn()` in the module file; below are pseudo code examples for single-worker
and multi-worker distribution strategies.

For a single-worker distribution strategy, you need to create an appropriate
[tf.distribute.Strategy](https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy)
and move the creation and compiling of the Keras model inside `strategy.scope`:

```python
def run_fn(args: trainer.executor.TrainerFnArgs) -> None:
  """Build the TF model and train it."""
  mirrored_strategy = tf.distribute.MirroredStrategy()
  with mirrored_strategy.scope():
    model = _build_keras_model()
  model.fit(...)
  model.save(...)
```

For a multi-worker distribution strategy, the TFX Trainer does not have the
ability to spawn a multi-worker cluster with the
[current executor](https://github.com/tensorflow/tfx/blob/r0.21/tfx/components/trainer/executor.py),
so that case is not covered in the scope of this RFC. If the execution
environment of a TFX Trainer implementation is able to bring up the cluster of
worker machines and execute the user function on the workers with the correct
[TF_CONFIG setup](https://www.tensorflow.org/guide/distributed_training#setting_up_tf_config_environment_variable),
such as the GCP AI Platform Training service via
[extensions/google_cloud_ai_platform/trainer/executor.py](https://github.com/tensorflow/tfx/blob/r0.21/tfx/extensions/google_cloud_ai_platform/trainer/executor.py),
the `run_fn()` would look like below:

```python
def _is_chief() -> bool:
  """Decides whether the current worker's role is chief."""
  # Check TF_CONFIG (set by TFX when bringing up the worker) in the execution
  # environment.
  ...


def run_fn(args: trainer.executor.TrainerFnArgs) -> None:
  """Build the TF model and train it."""
  ps_strategy = tf.distribute.experimental.ParameterServerStrategy()
  with ps_strategy.scope():
    model = _build_keras_model()
  model.fit(...)
  if _is_chief():
    model.save(...)
```
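
As an illustration only, `_is_chief()` could be implemented by inspecting the
standard `TF_CONFIG` environment variable (a JSON object with `cluster` and
`task` entries). The sketch below assumes the common convention that worker 0
acts as chief when the cluster has no dedicated chief task:

```python
import json
import os


def _is_chief() -> bool:
  """Returns True if this process should be the one to export the model."""
  tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
  cluster = tf_config.get('cluster', {})
  task = tf_config.get('task', {})
  task_type = task.get('type', 'chief')
  task_index = task.get('index', 0)
  # No TF_CONFIG (single worker) or an explicit chief task.
  if task_type == 'chief':
    return True
  # If the cluster has no dedicated chief, worker 0 takes that role.
  return task_type == 'worker' and task_index == 0 and 'chief' not in cluster
```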

For details about `tf.distribute.Strategy`, please refer to the
[distributed training guide](https://www.tensorflow.org/guide/distributed_training).

## Future work

* Examples for custom training loops.
* Native support for multi-worker distribution.