|
| 1 | +import json |
1 | 2 | import re
|
| 3 | +from enum import Enum |
| 4 | +from typing import List, Union |
2 | 5 |
|
3 | 6 | import pytest
|
4 | 7 | import torch
|
| 8 | +from pydantic import BaseModel, constr |
5 | 9 |
|
6 | 10 | import outlines.models as models
|
7 | 11 | import outlines.text.generate as generate
|
@@ -113,3 +117,111 @@ def test_transformers_integration_with_pad_token():
|
113 | 117 | model = models.transformers(model_name, device="cpu")
|
114 | 118 | assert model.tokenizer.pad_token_id == 1
|
115 | 119 | assert model.tokenizer.pad_token == "<pad>"
|
| 120 | + |
| 121 | + |
def test_transformers_json_basic():
    """Generate JSON for a flat pydantic schema and check every field's type.

    Uses a tiny random GPT-J checkpoint on CPU so the test exercises the
    constrained-generation machinery, not model quality.
    """
    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
    model = models.transformers(model_name, device="cpu")
    prompt = "Output some JSON "

    class Spam(BaseModel):
        foo: int
        bar: float
        spam: constr(max_length=10)
        fuzz: bool

    rng = torch.Generator()
    rng.manual_seed(0)  # make sure that `bar` is not an int

    sequence = generate.json(model, Spam, max_tokens=1000)(prompt, rng=rng)
    parsed = json.loads(sequence)
    assert isinstance(parsed["foo"], int)
    assert isinstance(parsed["bar"], float)
    assert isinstance(parsed["spam"], str)
    assert isinstance(parsed["fuzz"], bool)
    # `constr(max_length=10)` is an upper bound, not an exact length: any
    # string of length 0..10 satisfies the schema, so assert `<=`, not `==`.
    assert len(parsed["spam"]) <= 10
| 143 | + |
| 144 | + |
def test_transformers_json_str_enum():
    """Generated JSON must pick ``name`` from the string-enum members."""
    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
    model = models.transformers(model_name, device="cpu")
    prompt = "Output some JSON "

    rng = torch.Generator()
    rng.manual_seed(0)

    class Name(str, Enum):
        john = "John"
        marc = "Marc"
        michel = "Michel"

    class User(BaseModel):
        user_id: int
        name: Name

    raw_output = generate.json(model, User)(prompt, rng=rng)
    decoded = json.loads(raw_output)
    assert isinstance(decoded["user_id"], int)
    # The constrained decoder may only emit one of the enum's values.
    assert decoded["name"] in ["John", "Marc", "Michel"]
| 166 | + |
| 167 | + |
def test_transformers_json_int_enum():
    """Generated JSON must pick ``user_id`` from the integer-enum members."""
    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
    model = models.transformers(model_name, device="cpu")
    prompt = "Output some JSON "

    rng = torch.Generator()
    rng.manual_seed(0)

    class Id(int, Enum):
        one = 1
        two = 2

    class User(BaseModel):
        user_id: Id

    raw_output = generate.json(model, User)(prompt, rng=rng)
    decoded = json.loads(raw_output)
    user_id = decoded["user_id"]
    assert isinstance(user_id, int)
    # Only the enum's two values are legal outputs.
    assert user_id in [1, 2]
| 187 | + |
| 188 | + |
def test_transformers_json_array():
    """Generated JSON must produce a list of numbers for a ``List[float]`` field."""
    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
    model = models.transformers(model_name, device="cpu")
    prompt = "Output some JSON "

    class User(BaseModel):
        user_id: int
        value: List[float]

    rng = torch.Generator()
    rng.manual_seed(0)

    raw_output = generate.json(model, User)(prompt, rng=rng)
    decoded = json.loads(raw_output)
    assert isinstance(decoded["user_id"], int)
    assert isinstance(decoded["value"], list)
    # A JSON number without a decimal point parses as int, so both numeric
    # types are acceptable for the list elements.
    for element in decoded["value"]:
        assert isinstance(element, (float, int))
| 207 | + |
| 208 | + |
def test_transformers_json_union():
    """Generated JSON for a Union field may be either branch of the union."""
    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
    model = models.transformers(model_name, device="cpu")
    prompt = "Output some JSON "

    class Spam(BaseModel):
        foo: int
        bar: Union[constr(max_length=10), float]

    rng = torch.Generator()
    rng.manual_seed(4)

    raw_output = generate.json(model, Spam, max_tokens=100)(prompt, rng=rng)
    decoded = json.loads(raw_output)
    # `bar` may be a string (constr branch) or a number (float branch);
    # a JSON number without a decimal point parses as int.
    assert isinstance(decoded["bar"], (int, float, str))
0 commit comments