1
+ import os
2
+ from dotenv import load_dotenv
3
+ from utils import count_physical_cores
4
+ from torch .cuda import device_count
5
+
6
class EngineConfig:
    """Builds the vLLM engine argument dict from environment variables.

    Configuration precedence: a local override file baked into the image
    (e.g. ``/local_model_path.txt``) wins over environment variables; env
    vars are loaded via python-dotenv so a ``.env`` file also works.
    The resulting kwargs live in ``self.config``.
    """

    def __init__(self):
        load_dotenv()
        # Model identity: local path file wins over MODEL_NAME env var.
        self.model_name_or_path, self.hf_home, self.model_revision = self._get_local_or_env(
            "/local_model_path.txt", "MODEL_NAME"
        )
        self.tokenizer_name_or_path, _, self.tokenizer_revision = self._get_local_or_env(
            "/local_tokenizer_path.txt", "TOKENIZER_NAME"
        )
        # Fall back to the model path when no tokenizer is configured.
        self.tokenizer_name_or_path = self.tokenizer_name_or_path or self.model_name_or_path
        self.quantization = self._get_quantization()
        self.config = self._initialize_config()

    def _get_local_or_env(self, local_path, env_var):
        """Return ``(value, hf_home, revision)`` for a model/tokenizer source.

        If ``local_path`` exists, the artifact was baked into the image:
        return the file's contents and no download dir / revision.
        Otherwise read ``env_var``, ``HF_HOME``, and ``<env_var>_REVISION``
        from the environment (each may be None when unset).
        """
        if os.path.exists(local_path):
            with open(local_path, "r") as file:
                return file.read().strip(), None, None
        return os.getenv(env_var), os.getenv("HF_HOME"), os.getenv(f"{env_var}_REVISION")

    def _get_quantization(self):
        """Return the requested quantization scheme, or None if unset/unsupported."""
        quantization = os.getenv("QUANTIZATION", "").lower()
        return quantization if quantization in ("awq", "squeezellm", "gptq") else None

    @staticmethod
    def _int_env(name):
        """Return ``int(os.getenv(name))``, or None when unset or empty.

        Reads the variable once, unlike the inline
        ``int(os.getenv(X)) if os.getenv(X) else None`` pattern.
        """
        value = os.getenv(name)
        return int(value) if value else None

    @staticmethod
    def _bool_env(name, default):
        """Interpret an env var holding '0'/'1' as a bool, with ``default`` when unset."""
        return bool(int(os.getenv(name, default)))

    def _initialize_config(self):
        """Assemble engine kwargs, dropping None entries so vLLM defaults apply."""
        num_gpus = device_count()  # query CUDA once; used twice below
        args = {
            "model": self.model_name_or_path,
            "revision": self.model_revision,
            "download_dir": self.hf_home,
            "quantization": self.quantization,
            "load_format": os.getenv("LOAD_FORMAT", "auto"),
            # Quantized weights generally require half precision.
            "dtype": os.getenv("DTYPE", "half" if self.quantization else "auto"),
            "tokenizer": self.tokenizer_name_or_path,
            "tokenizer_revision": self.tokenizer_revision,
            "disable_log_stats": self._bool_env("DISABLE_LOG_STATS", 1),
            "disable_log_requests": self._bool_env("DISABLE_LOG_REQUESTS", 1),
            "trust_remote_code": self._bool_env("TRUST_REMOTE_CODE", 0),
            "gpu_memory_utilization": float(os.getenv("GPU_MEMORY_UTILIZATION", 0.95)),
            # Parallel loading workers are only honored on single-GPU setups.
            "max_parallel_loading_workers": (
                None if num_gpus > 1 else self._int_env("MAX_PARALLEL_LOADING_WORKERS")
            ),
            "max_model_len": self._int_env("MAX_MODEL_LENGTH"),
            "tensor_parallel_size": num_gpus,
            "seed": self._int_env("SEED"),
            "kv_cache_dtype": os.getenv("KV_CACHE_DTYPE"),
            "block_size": self._int_env("BLOCK_SIZE"),
            "swap_space": self._int_env("SWAP_SPACE"),
            "max_context_len_to_capture": self._int_env("MAX_CONTEXT_LEN_TO_CAPTURE"),
            "disable_custom_all_reduce": self._bool_env("DISABLE_CUSTOM_ALL_REDUCE", 0),
            "enforce_eager": self._bool_env("ENFORCE_EAGER", 0),
        }
        # Drop unset options so vLLM applies its own defaults.
        return {key: value for key, value in args.items() if value is not None}