Skip to content

Commit 00ad26e

Browse files
committed
convert.py: Experimental args bitrate tweaking
1 parent 49ac13c commit 00ad26e

File tree

5 files changed

+48
-8
lines changed

5 files changed

+48
-8
lines changed

exllamav3/conversion/convert_model.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
parser.add_argument("-img", "--image_dump", action = "store_true", help = "Save model tensors as images (saved to working directory)")
4141
parser.add_argument("-mcg", "--mcg_multiplier", type = str, default = None, help = "MCG multiplier - EXPERIMENTAL, DO NOT USE")
4242
parser.add_argument("-mul1", "--mul1_multiplier", type = str, default = None, help = "MUL1 multiplier - EXPERIMENTAL, DO NOT USE")
43+
parser.add_argument("-strat", "--strategy", type = str, default = None, help = "Modifiers for quantization strategy - EXPERIMENTAL")
4344

4445
group = parser.add_mutually_exclusive_group()
4546
group.add_argument("--out_scales", dest = "out_scales_", action = "store_true", help = "Always enable out channel scales (for debug purposes)")
@@ -154,6 +155,7 @@ def override(arg, can_override, default):
154155
("device_ratios", True, None),
155156
("mcg_multiplier", True, ""),
156157
("mul1_multiplier", True, ""),
158+
("strategy", False, ""),
157159
]:
158160
override(arg_, can_override if not args.override_anyway else True, default)
159161

@@ -233,6 +235,32 @@ def get_state_error(x, ref):
233235
return err.item(), cos, sq
234236

235237

238+
def mod_strategy(args, module, strategy, idx):
    """Apply experimental per-layer bitrate modifiers from the --strategy arg.

    The "strategy" argument is a ";"-separated list of per-module specs. Each
    spec is a flat sequence of (letter, digit) pairs, e.g. "q2k1": the letter
    selects submodules by their qbits_mod_key ("q", "k", "v", "o", "u", "d",
    "g") and the digit is added to their bit allocation, capped at 8 bits.
    An empty spec is prepended so the first ";"-separated entry applies to
    idx == 1 (presumably module 0 is the embedding layer -- TODO confirm).

    :param args:
        dict-like job arguments; read via args.get("strategy")
    :param module:
        model module providing find_module(key) -> submodule or None
    :param strategy:
        dict of {tensor key: bits} produced by the allocator
    :param idx:
        index of the current module in the model
    :return:
        new {tensor key: bits} dict, or the original strategy unchanged when
        no modifier applies
    :raises ValueError:
        if the spec for this module is malformed (odd length or non-digit
        modifier)
    """
    mod_arg = args.get("strategy")
    if not mod_arg:
        return strategy

    s_layers = [""] + mod_arg.split(";")
    if idx >= len(s_layers):
        return strategy

    spec = s_layers[idx]
    if len(spec) % 2 != 0:
        raise ValueError(f"Malformed --strategy spec for module {idx}: {spec!r}")

    # Parse (letter, digit) pairs into {qbits_mod_key: bit delta}; int() raises
    # ValueError on a non-digit modifier, which is the desired error for bad input
    mod = {spec[i]: int(spec[i + 1]) for i in range(0, len(spec), 2)}

    new_strategy = {}
    for key, bits in strategy.items():
        submodule = module.find_module(key)
        # find_module yields None for unknown keys; leave those entries unmodified
        modifier = mod.get(submodule.qbits_mod_key, 0) if submodule is not None else 0
        new_strategy[key] = min(bits + modifier, 8)  # cap at 8 bits per weight

    # TODO: Automate this, also calculate overall increase in bitrate, track in job.json across resumes
    return new_strategy
262+
263+
236264
@torch.inference_mode()
237265
def main(args, job_state):
238266

@@ -281,6 +309,7 @@ def main(args, job_state):
281309
},
282310
job_state["surplus_bits"],
283311
)
312+
strategy = mod_strategy(args, module, strategy, idx)
284313
job_state["surplus_bits"] = surplus
285314

286315
# Slice module if necessary

exllamav3/modules/attn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,10 @@ def __init__(
174174
else:
175175
fkey, frange_q, frange_k, frange_v = None, None, None, None
176176

177-
self.q_proj = Linear(config, f"{key}.{key_q}", hidden_size, num_q_heads * head_dim, qmap = qmap + ".input", fkey = fkey, frange = frange_q)
178-
self.k_proj = Linear(config, f"{key}.{key_k}", hidden_size, num_kv_heads * head_dim, qmap = qmap + ".input", fkey = fkey, frange = frange_k)
179-
self.v_proj = Linear(config, f"{key}.{key_v}", hidden_size, num_kv_heads * head_dim, qmap = qmap + ".input", fkey = fkey, frange = frange_v)
180-
self.o_proj = Linear(config, f"{key}.{key_o}", num_q_heads * head_dim, hidden_size, qmap = qmap + ".o", out_dtype = out_dtype)
177+
self.q_proj = Linear(config, f"{key}.{key_q}", hidden_size, num_q_heads * head_dim, qmap = qmap + ".input", fkey = fkey, frange = frange_q, qbits_mod_key = "q")
178+
self.k_proj = Linear(config, f"{key}.{key_k}", hidden_size, num_kv_heads * head_dim, qmap = qmap + ".input", fkey = fkey, frange = frange_k, qbits_mod_key = "k")
179+
self.v_proj = Linear(config, f"{key}.{key_v}", hidden_size, num_kv_heads * head_dim, qmap = qmap + ".input", fkey = fkey, frange = frange_v, qbits_mod_key = "v")
180+
self.o_proj = Linear(config, f"{key}.{key_o}", num_q_heads * head_dim, hidden_size, qmap = qmap + ".o", out_dtype = out_dtype, qbits_mod_key = "o")
181181

182182
self.register_submodule(self.q_proj)
183183
self.register_submodule(self.k_proj)

exllamav3/modules/linear.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(
2424
qmap: str | None = None,
2525
alt_key: str | None = None,
2626
qbits_key: str = "bits",
27+
qbits_mod_key: str = "",
2728
fkey : str | None = None,
2829
frange: tuple[int, int] | None = None,
2930
caps: dict = None,
@@ -50,6 +51,7 @@ def __init__(
5051
self.first_out_feature = first_out_feature if first_out_feature is not None else 0
5152
self.inner = None
5253
self.qbits_key = qbits_key
54+
self.qbits_mod_key = qbits_mod_key
5355
self.fkey = fkey
5456
self.frange = frange
5557
self.quant_type = None

exllamav3/modules/mlp.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ def __init__(
2828

2929
self.out_dtype = out_dtype
3030

31-
self.up = Linear(config, f"{key}.{key_up}", hidden_size, intermediate_size, qmap = qmap + ".up")
32-
self.down = Linear(config, f"{key}.{key_down}", intermediate_size, hidden_size, qmap = qmap + ".down")
31+
self.up = Linear(config, f"{key}.{key_up}", hidden_size, intermediate_size, qmap = qmap + ".up", qbits_mod_key = "u")
32+
self.down = Linear(config, f"{key}.{key_down}", intermediate_size, hidden_size, qmap = qmap + ".down", qbits_mod_key = "d")
3333

3434
self.register_submodule(self.up)
3535
self.register_submodule(self.down)
@@ -129,7 +129,8 @@ def __init__(
129129
fkey = fkey,
130130
frange = frange_gate,
131131
alt_key = a_key_g,
132-
out_dtype = self.interm_dtype
132+
out_dtype = self.interm_dtype,
133+
qbits_mod_key = "g"
133134
)
134135
up = Linear(
135136
config = config,
@@ -144,7 +145,8 @@ def __init__(
144145
fkey = fkey,
145146
frange = frange_up,
146147
alt_key = a_key_u,
147-
out_dtype = self.interm_dtype
148+
out_dtype = self.interm_dtype,
149+
qbits_mod_key = "u"
148150
)
149151
down = Linear(
150152
config = config,
@@ -159,6 +161,7 @@ def __init__(
159161
alt_key = a_key_d,
160162
out_dtype = self.out_dtype,
161163
allow_input_padding = True,
164+
qbits_mod_key = "d"
162165
)
163166

164167
self.ups.append(up)

exllamav3/modules/module.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,18 @@ def __init__(
3636
self.caps = {}
3737
self.qmap = qmap
3838
self.num_slices = 1
39+
self.qbits_mod_key = ""
3940

4041
def __iter__(self):
4142
yield self
4243
for module in self.modules:
4344
yield from module
4445

46+
def find_module(self, key: str):
    """Return the first module in this subtree whose key equals *key*.

    Walks the module tree depth-first (via __iter__, which yields self first).
    Returns None when no module matches.
    """
    return next((m for m in self if m.key == key), None)
50+
4551
def can_defer_load(self):
4652
if len(self.modules) == 0: return True
4753
return all(module.can_defer_load() for module in self.modules)

0 commit comments

Comments
 (0)