Skip to content

Commit a08c0a3

Browse files
committed
Refactor JWT helper functions
Change-Id: I1cc08489a8a947e0fa977e43659ce3fb77356dd5 Reviewed-on: https://review.couchbase.org/c/testrunner/+/237236 Reviewed-by: Saimirra R <[email protected]> Tested-by: Samridh Anand <[email protected]>
1 parent 90e7551 commit a08c0a3

File tree

2 files changed

+157
-80
lines changed

2 files changed

+157
-80
lines changed

pytests/security/jwt_token_test.py

Lines changed: 21 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from pytests.autofailovertests import AutoFailoverBaseTest
1818
from pytests.security.internal_user import InternalUser
1919
from pytests.security.external_user import ExternalUser
20+
from pytests.security.jwt_utils import JWTUtils
2021

2122
class JWTTokenTest(OnPremBaseTestCase):
2223
def setUp(self):
@@ -33,7 +34,8 @@ def setUp(self):
3334
self.subject = self.input.param("subject", "This is a subject")
3435
self.ttl = self.input.param("ttl", 300)
3536
self.nbf_seconds = self.input.param("nbf_seconds", 0)
36-
self.private_key, self.pub_key = self._generate_key_pair(algorithm=self.algorithm, key_size=self.key_size)
37+
self.jwt_utils = JWTUtils(log=self.log)
38+
self.private_key, self.pub_key = self.jwt_utils.generate_key_pair(algorithm=self.algorithm, key_size=self.key_size)
3739
self._enable_dev_preview()
3840

3941
def _enable_dev_preview(self):
@@ -49,92 +51,31 @@ def _enable_dev_preview(self):
4951
if err:
5052
self.fail("Failed to enable Developer Preview Mode")
5153

52-
def _generate_key_pair(self, algorithm: str, key_size: int):
53-
"""Generate a key pair for JWT signing.
54-
Args:
55-
algorithm: JWT signing algorithm (RS256, ES256, etc.)
56-
key_size: Key size in bits for RSA algorithms (minimum 2048)
57-
Returns:
58-
tuple: (private_key_pem, public_key_pem) as strings
59-
"""
60-
EC_CURVE_MAP = {
61-
"ES256": ec.SECP256R1(),
62-
"ES384": ec.SECP384R1(),
63-
"ES512": ec.SECP521R1(),
64-
"ES256K": ec.SECP256K1(),
65-
}
66-
algorithm = algorithm.upper()
67-
self.log.info(f"Generating key pair for algorithm: {algorithm}")
68-
if algorithm.startswith("RS") or algorithm.startswith("PS"):
69-
if key_size < 2048:
70-
self.fail(f"RSA key size must be 2048 or greater. Got {key_size}")
71-
private_key = rsa.generate_private_key(
72-
public_exponent=65537,
73-
key_size=key_size,
74-
)
75-
elif algorithm in EC_CURVE_MAP:
76-
curve = EC_CURVE_MAP[algorithm]
77-
private_key = ec.generate_private_key(curve)
78-
else:
79-
self.fail(f"Unsupported algorithm: {algorithm}. Supported: RS*, PS*, ES*")
80-
if not private_key:
81-
self.fail("Error while creating key pair")
82-
private_key_pem = private_key.private_bytes(
83-
encoding=serialization.Encoding.PEM,
84-
format=serialization.PrivateFormat.PKCS8,
85-
encryption_algorithm=serialization.NoEncryption()
86-
)
87-
public_key = private_key.public_key()
88-
public_key_pem = public_key.public_bytes(
89-
encoding=serialization.Encoding.PEM,
90-
format=serialization.PublicFormat.SubjectPublicKeyInfo
91-
)
92-
self.log.info("Key pair generated successfully.")
93-
return private_key_pem.decode(), public_key_pem.decode()
94-
9554
def _get_jwt_config(self, jit_provisioning=True):
9655
"""Get JWT configuration with configurable JIT provisioning
9756
Args:
9857
jit_provisioning (bool): Enable/disable JIT user provisioning. Default: True
9958
"""
100-
return {
101-
"enabled": True,
102-
"issuers": [
103-
{
104-
"name": self.issuer_name,
105-
"signingAlgorithm": self.algorithm,
106-
"publicKeySource": "pem",
107-
"publicKey": self.pub_key,
108-
"jitProvisioning": jit_provisioning,
109-
"subClaim": "sub",
110-
"audClaim": "aud",
111-
"audienceHandling": "any",
112-
"audiences": self.token_audience,
113-
"groupsClaim": "groups",
114-
"groupsMaps": self.token_group_matching_rule
115-
}
116-
]
117-
}
59+
return self.jwt_utils.get_jwt_config(
60+
issuer_name=self.issuer_name,
61+
algorithm=self.algorithm,
62+
pub_key=self.pub_key,
63+
token_audience=self.token_audience,
64+
token_group_matching_rule=self.token_group_matching_rule,
65+
jit_provisioning=jit_provisioning
66+
)
11867

11968
def create_token(self):
120-
curr_time = int(time.time())
121-
payload = {
122-
"iss": self.issuer_name,
123-
"sub": self.user_name,
124-
"exp": curr_time+self.ttl,
125-
"iat": curr_time,
126-
"nbf": curr_time-self.nbf_seconds,
127-
"jti": str(uuid.uuid4()),
128-
}
129-
if self.token_audience:
130-
payload['aud'] = self.token_audience
131-
if self.user_groups:
132-
payload['groups'] = self.user_groups
133-
self.log.info(f"Creating JWT token with payload: {json.dumps(payload, indent=2)}")
134-
jwt_token = jwt.encode(payload=payload,
135-
algorithm=self.algorithm,
136-
key=self.private_key)
137-
return jwt_token
69+
return self.jwt_utils.create_token(
70+
issuer_name=self.issuer_name,
71+
user_name=self.user_name,
72+
algorithm=self.algorithm,
73+
private_key=self.private_key,
74+
token_audience=self.token_audience,
75+
user_groups=self.user_groups,
76+
ttl=self.ttl,
77+
nbf_seconds=self.nbf_seconds
78+
)
13879

13980
def _debug_jwt_config(self, rest_conn):
14081
"""Debug method to check JWT configuration"""

pytests/security/jwt_utils.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import time
2+
import uuid
3+
import jwt
4+
import json
5+
from cryptography.hazmat.primitives import serialization
6+
from cryptography.hazmat.primitives.asymmetric import rsa, ec
7+
8+
9+
class JWTUtils:
10+
"""Utility class for JWT token generation and configuration"""
11+
12+
EC_CURVE_MAP = {
13+
"ES256": ec.SECP256R1(),
14+
"ES384": ec.SECP384R1(),
15+
"ES512": ec.SECP521R1(),
16+
"ES256K": ec.SECP256K1(),
17+
}
18+
19+
def __init__(self, log=None):
20+
"""Initialize JWTUtils
21+
Args:
22+
log: Logger instance for logging messages (optional)
23+
"""
24+
self.log = log
25+
26+
def generate_key_pair(self, algorithm: str, key_size: int = 2048):
27+
"""Generate a key pair for JWT signing.
28+
Args:
29+
algorithm: JWT signing algorithm (RS256, ES256, etc.)
30+
key_size: Key size in bits for RSA algorithms (minimum 2048)
31+
Returns:
32+
tuple: (private_key_pem, public_key_pem) as strings
33+
"""
34+
algorithm = algorithm.upper()
35+
if self.log:
36+
self.log.info(f"Generating key pair for algorithm: {algorithm}")
37+
38+
if algorithm.startswith("RS") or algorithm.startswith("PS"):
39+
if key_size < 2048:
40+
raise ValueError(f"RSA key size must be 2048 or greater. Got {key_size}")
41+
private_key = rsa.generate_private_key(
42+
public_exponent=65537,
43+
key_size=key_size,
44+
)
45+
elif algorithm in self.EC_CURVE_MAP:
46+
curve = self.EC_CURVE_MAP[algorithm]
47+
private_key = ec.generate_private_key(curve)
48+
else:
49+
raise ValueError(f"Unsupported algorithm: {algorithm}. Supported: RS*, PS*, ES*")
50+
51+
if not private_key:
52+
raise RuntimeError("Error while creating key pair")
53+
54+
private_key_pem = private_key.private_bytes(
55+
encoding=serialization.Encoding.PEM,
56+
format=serialization.PrivateFormat.PKCS8,
57+
encryption_algorithm=serialization.NoEncryption()
58+
)
59+
public_key = private_key.public_key()
60+
public_key_pem = public_key.public_bytes(
61+
encoding=serialization.Encoding.PEM,
62+
format=serialization.PublicFormat.SubjectPublicKeyInfo
63+
)
64+
if self.log:
65+
self.log.info("Key pair generated successfully.")
66+
return private_key_pem.decode(), public_key_pem.decode()
67+
68+
def get_jwt_config(self, issuer_name, algorithm, pub_key, token_audience,
69+
token_group_matching_rule, jit_provisioning=True):
70+
"""Get JWT configuration with configurable JIT provisioning
71+
Args:
72+
issuer_name: Name of the JWT issuer
73+
algorithm: JWT signing algorithm
74+
pub_key: Public key in PEM format
75+
token_audience: List of audiences for the token
76+
token_group_matching_rule: List of group matching rules
77+
jit_provisioning (bool): Enable/disable JIT user provisioning. Default: True
78+
Returns:
79+
dict: JWT configuration dictionary
80+
"""
81+
return {
82+
"enabled": True,
83+
"issuers": [
84+
{
85+
"name": issuer_name,
86+
"signingAlgorithm": algorithm,
87+
"publicKeySource": "pem",
88+
"publicKey": pub_key,
89+
"jitProvisioning": jit_provisioning,
90+
"subClaim": "sub",
91+
"audClaim": "aud",
92+
"audienceHandling": "any",
93+
"audiences": token_audience,
94+
"groupsClaim": "groups",
95+
"groupsMaps": token_group_matching_rule
96+
}
97+
]
98+
}
99+
100+
def create_token(self, issuer_name, user_name, algorithm, private_key,
101+
token_audience=None, user_groups=None, ttl=300,
102+
nbf_seconds=0):
103+
"""Create a JWT token with the specified parameters
104+
Args:
105+
issuer_name: Name of the JWT issuer
106+
user_name: Username for the token subject
107+
algorithm: JWT signing algorithm
108+
private_key: Private key for signing the token
109+
token_audience: List of audiences for the token (optional)
110+
user_groups: List of user groups (optional)
111+
ttl: Time to live in seconds (default: 300)
112+
nbf_seconds: Not before offset in seconds (default: 0)
113+
Returns:
114+
str: Encoded JWT token
115+
"""
116+
curr_time = int(time.time())
117+
payload = {
118+
"iss": issuer_name,
119+
"sub": user_name,
120+
"exp": curr_time + ttl,
121+
"iat": curr_time,
122+
"nbf": curr_time - nbf_seconds,
123+
"jti": str(uuid.uuid4()),
124+
}
125+
if token_audience:
126+
payload['aud'] = token_audience
127+
if user_groups:
128+
payload['groups'] = user_groups
129+
130+
if self.log:
131+
self.log.info(f"Creating JWT token with payload: {json.dumps(payload, indent=2)}")
132+
133+
jwt_token = jwt.encode(payload=payload,
134+
algorithm=algorithm,
135+
key=private_key)
136+
return jwt_token

0 commit comments

Comments
 (0)