Skip to content

Commit 31c1c41

Browse files
authored
Nexus: evolve link conversion and query param processing (#953)
- Accept both NexusOperationScheduled and EVENT_TYPE_NEXUS_OPERATION_SCHEDULED event type names in query params.
- Always emit PascalCase (e.g. WorkflowExecutionStarted) when returning links to the caller workflow.
- Evolve the test suite so that it can be used against an external server.
1 parent bdddb07 commit 31c1c41

10 files changed

+387
-142
lines changed

temporalio/nexus/_link_conversion.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
import re
5+
import urllib.parse
6+
from typing import (
7+
Any,
8+
Optional,
9+
)
10+
11+
import nexusrpc
12+
13+
import temporalio.api.common.v1
14+
import temporalio.api.enums.v1
15+
import temporalio.client
16+
17+
logger = logging.getLogger(__name__)

# Path component of a workflow-event link URL. The named groups capture the
# namespace, workflow ID and run ID segments exactly as they appear in the URL
# (i.e. still percent-encoded); none of them may contain a literal "/".
_LINK_URL_PATH_REGEX = re.compile(
    r"^/namespaces/(?P<namespace>[^/]+)/workflows/(?P<workflow_id>[^/]+)/(?P<run_id>[^/]+)/history$"
)
# Names of the URL query parameters that carry the event ID and event type of
# an event reference.
LINK_EVENT_ID_PARAM_NAME = "eventID"
LINK_EVENT_TYPE_PARAM_NAME = "eventType"
24+
25+
26+
def workflow_handle_to_workflow_execution_started_event_link(
    handle: temporalio.client.WorkflowHandle[Any, Any],
) -> temporalio.api.common.v1.Link.WorkflowEvent:
    """Create a WorkflowEvent link corresponding to a started workflow.

    The link points at event ID 1 with event type
    ``EVENT_TYPE_WORKFLOW_EXECUTION_STARTED`` of the run identified by
    ``handle.first_execution_run_id``.

    Args:
        handle: Handle of the workflow that was started.

    Returns:
        A WorkflowEvent link carrying the handle's client namespace, workflow ID
        and first execution run ID.

    Raises:
        ValueError: If the handle has no first execution run ID.
    """
    if handle.first_execution_run_id is None:
        raise ValueError(
            f"Workflow handle {handle} has no first execution run ID. "
            "Cannot create WorkflowExecutionStarted event link."
        )
    return temporalio.api.common.v1.Link.WorkflowEvent(
        namespace=handle._client.namespace,
        workflow_id=handle.id,
        run_id=handle.first_execution_run_id,
        event_ref=temporalio.api.common.v1.Link.WorkflowEvent.EventReference(
            event_id=1,
            event_type=temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED,
        ),
        # TODO(nexus-preview): RequestIdReference
    )
45+
46+
47+
def workflow_event_to_nexus_link(
    workflow_event: temporalio.api.common.v1.Link.WorkflowEvent,
) -> nexusrpc.Link:
    """Convert a WorkflowEvent link into a nexusrpc link.

    Used when propagating links from a StartWorkflow response to a Nexus start
    operation response.

    Args:
        workflow_event: The WorkflowEvent link to convert.

    Returns:
        A nexusrpc link whose URL has scheme ``temporal``, a path of the form
        ``/namespaces/<namespace>/workflows/<workflow_id>/<run_id>/history`` and
        the event reference encoded as query parameters. The link type is the
        full protobuf name of the WorkflowEvent message.
    """
    scheme = "temporal"
    # Percent-encode with safe="" so that a "/" inside any component becomes
    # "%2F". With the default (safe="/") such a component would add path
    # segments and the resulting URL would no longer match the path pattern
    # used to parse links back into WorkflowEvents.
    namespace = urllib.parse.quote(workflow_event.namespace, safe="")
    workflow_id = urllib.parse.quote(workflow_event.workflow_id, safe="")
    run_id = urllib.parse.quote(workflow_event.run_id, safe="")
    path = f"/namespaces/{namespace}/workflows/{workflow_id}/{run_id}/history"
    query_params = _event_reference_to_query_params(workflow_event.event_ref)
    return nexusrpc.Link(
        url=urllib.parse.urlunparse((scheme, "", path, "", query_params, "")),
        type=workflow_event.DESCRIPTOR.full_name,
    )
65+
66+
67+
def nexus_link_to_workflow_event(
    link: nexusrpc.Link,
) -> Optional[temporalio.api.common.v1.Link.WorkflowEvent]:
    """Convert a nexus link into a WorkflowEvent link.

    This is used when propagating links from a Nexus start operation request to a
    StartWorkflow request.

    Args:
        link: The nexus link whose URL is to be parsed.

    Returns:
        The corresponding WorkflowEvent link, or ``None`` (after logging a
        warning) if the URL path does not have the expected shape or the event
        reference cannot be parsed from the query parameters.
    """
    url = urllib.parse.urlparse(link.url)
    match = _LINK_URL_PATH_REGEX.match(url.path)
    if not match:
        logger.warning(
            "Invalid Nexus link: %s. Expected path to match %s",
            link,
            _LINK_URL_PATH_REGEX.pattern,
        )
        return None
    try:
        event_ref = _query_params_to_event_reference(url.query)
    except ValueError as err:
        logger.warning(
            "Failed to parse event reference from Nexus link URL query parameters: %s (%s)",
            link,
            err,
        )
        return None

    # The path components are percent-encoded in the URL; decode them before
    # placing them in the protobuf message.
    groups = match.groupdict()
    return temporalio.api.common.v1.Link.WorkflowEvent(
        namespace=urllib.parse.unquote(groups["namespace"]),
        workflow_id=urllib.parse.unquote(groups["workflow_id"]),
        run_id=urllib.parse.unquote(groups["run_id"]),
        event_ref=event_ref,
    )
97+
98+
99+
def _event_reference_to_query_params(
    event_ref: temporalio.api.common.v1.Link.WorkflowEvent.EventReference,
) -> str:
    """Encode an EventReference as a URL query string.

    The event type is emitted in PascalCase (e.g. ``WorkflowExecutionStarted``):
    the ``EVENT_TYPE_`` prefix of the enum name is dropped and the remainder is
    converted from CONSTANT_CASE.

    Args:
        event_ref: The event reference to encode.

    Returns:
        A URL-encoded query string containing the event ID, the event type and
        ``referenceType=EventReference``.
    """
    event_type_name = temporalio.api.enums.v1.EventType.Name(event_ref.event_type)
    if event_type_name.startswith("EVENT_TYPE_"):
        event_type_name = _event_type_constant_case_to_pascal_case(
            event_type_name.removeprefix("EVENT_TYPE_")
        )
    # Use the shared parameter-name constants so that encoding and parsing of
    # links cannot drift apart.
    return urllib.parse.urlencode(
        {
            LINK_EVENT_ID_PARAM_NAME: event_ref.event_id,
            LINK_EVENT_TYPE_PARAM_NAME: event_type_name,
            "referenceType": "EventReference",
        }
    )
114+
115+
116+
def _single_query_param(query_params: dict[str, list[str]], name: str) -> str:
    """Return the sole value of query parameter ``name``, or "" if it is absent.

    Raises:
        ValueError: If the parameter occurs more than once.
    """
    values = query_params.get(name) or [""]
    if len(values) > 1:
        raise ValueError(
            f"Expected at most one value for query parameter {name} but got: {values}"
        )
    return values[0]


def _query_params_to_event_reference(
    raw_query_params: str,
) -> temporalio.api.common.v1.Link.WorkflowEvent.EventReference:
    """Return an EventReference from the query params or raise ValueError.

    The event type may be given either as the full enum name (e.g.
    ``EVENT_TYPE_NEXUS_OPERATION_SCHEDULED``) or in PascalCase (e.g.
    ``NexusOperationScheduled``). The event ID is optional and defaults to 0.

    Args:
        raw_query_params: The raw (still URL-encoded) query string of a link URL.

    Raises:
        ValueError: If ``referenceType`` is not ``EventReference``, a parameter
            is repeated, the event type is missing or unrecognized, or the
            event ID is not an integer.
    """
    query_params = urllib.parse.parse_qs(raw_query_params)

    reference_type = _single_query_param(query_params, "referenceType")
    if reference_type != "EventReference":
        raise ValueError(
            f"Expected Nexus link URL query parameter referenceType to be EventReference but got: {reference_type}"
        )

    # event type
    raw_event_type_name = _single_query_param(query_params, LINK_EVENT_TYPE_PARAM_NAME)
    if not raw_event_type_name:
        raise ValueError(f"query params do not contain event type: {query_params}")
    if raw_event_type_name.startswith("EVENT_TYPE_"):
        event_type_name = raw_event_type_name
    elif re.match(r"[A-Z][a-z]", raw_event_type_name):
        # Looks like PascalCase: convert to the enum's CONSTANT_CASE name.
        event_type_name = "EVENT_TYPE_" + _event_type_pascal_case_to_constant_case(
            raw_event_type_name
        )
    else:
        raise ValueError(f"Invalid event type name: {raw_event_type_name}")

    # event id
    event_id = 0
    raw_event_id = _single_query_param(query_params, LINK_EVENT_ID_PARAM_NAME)
    if raw_event_id:
        try:
            event_id = int(raw_event_id)
        except ValueError as err:
            raise ValueError(
                f"Query params contain invalid event id: {raw_event_id}"
            ) from err

    return temporalio.api.common.v1.Link.WorkflowEvent.EventReference(
        event_type=temporalio.api.enums.v1.EventType.Value(event_type_name),
        event_id=event_id,
    )
153+
154+
155+
def _event_type_constant_case_to_pascal_case(s: str) -> str:
    """Convert a CONSTANT_CASE string to PascalCase.

    >>> _event_type_constant_case_to_pascal_case("NEXUS_OPERATION_SCHEDULED")
    'NexusOperationScheduled'
    """
    # Lower-case everything, then upper-case the first letter of the string and
    # every letter that follows an underscore, dropping that underscore.
    return re.sub(r"(\b|_)([a-z])", lambda m: m.group(2).upper(), s.lower())
162+
163+
164+
def _event_type_pascal_case_to_constant_case(s: str) -> str:
    """Convert a PascalCase string to CONSTANT_CASE.

    >>> _event_type_pascal_case_to_constant_case("NexusOperationScheduled")
    'NEXUS_OPERATION_SCHEDULED'
    """
    # Put an underscore in front of every upper-case letter, drop the leading
    # underscore(s) this produces, then upper-case the whole string.
    return re.sub(r"([A-Z])", r"_\1", s).lstrip("_").upper()

temporalio/nexus/_operation_context.py

Lines changed: 4 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
import dataclasses
44
import logging
5-
import re
6-
import urllib.parse
75
from contextvars import ContextVar
86
from dataclasses import dataclass
97
from datetime import timedelta
@@ -20,14 +18,13 @@
2018
overload,
2119
)
2220

23-
import nexusrpc.handler
2421
from nexusrpc.handler import CancelOperationContext, StartOperationContext
2522
from typing_extensions import Concatenate
2623

2724
import temporalio.api.common.v1
28-
import temporalio.api.enums.v1
2925
import temporalio.client
3026
import temporalio.common
27+
from temporalio.nexus import _link_conversion
3128
from temporalio.nexus._token import WorkflowHandle
3229
from temporalio.types import (
3330
MethodAsyncNoParam,
@@ -128,11 +125,6 @@ def _get_callbacks(
128125
ctx = self.nexus_context
129126
return (
130127
[
131-
# TODO(nexus-prerelease): For WorkflowRunOperation, when it handles the Nexus
132-
# request, it needs to copy the links to the callback in
133-
# StartWorkflowRequest.CompletionCallbacks and to StartWorkflowRequest.Links
134-
# (for backwards compatibility). PR reference in Go SDK:
135-
# https://github.com/temporalio/sdk-go/pull/1945
136128
temporalio.client.NexusCallback(
137129
url=ctx.callback_url,
138130
headers=ctx.callback_headers,
@@ -147,16 +139,16 @@ def _get_workflow_event_links(
147139
) -> list[temporalio.api.common.v1.Link.WorkflowEvent]:
148140
event_links = []
149141
for inbound_link in self.nexus_context.inbound_links:
150-
if link := _nexus_link_to_workflow_event(inbound_link):
142+
if link := _link_conversion.nexus_link_to_workflow_event(inbound_link):
151143
event_links.append(link)
152144
return event_links
153145

154146
def _add_outbound_links(
155147
self, workflow_handle: temporalio.client.WorkflowHandle[Any, Any]
156148
):
157149
try:
158-
link = _workflow_event_to_nexus_link(
159-
_workflow_handle_to_workflow_execution_started_event_link(
150+
link = _link_conversion.workflow_event_to_nexus_link(
151+
_link_conversion.workflow_handle_to_workflow_execution_started_event_link(
160152
workflow_handle
161153
)
162154
)
@@ -479,91 +471,6 @@ def set(self) -> None:
479471
_temporal_cancel_operation_context.set(self)
480472

481473

482-
def _workflow_handle_to_workflow_execution_started_event_link(
483-
handle: temporalio.client.WorkflowHandle[Any, Any],
484-
) -> temporalio.api.common.v1.Link.WorkflowEvent:
485-
if handle.first_execution_run_id is None:
486-
raise ValueError(
487-
f"Workflow handle {handle} has no first execution run ID. "
488-
"Cannot create WorkflowExecutionStarted event link."
489-
)
490-
return temporalio.api.common.v1.Link.WorkflowEvent(
491-
namespace=handle._client.namespace,
492-
workflow_id=handle.id,
493-
run_id=handle.first_execution_run_id,
494-
event_ref=temporalio.api.common.v1.Link.WorkflowEvent.EventReference(
495-
event_id=1,
496-
event_type=temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED,
497-
),
498-
# TODO(nexus-prerelease): RequestIdReference?
499-
)
500-
501-
502-
def _workflow_event_to_nexus_link(
503-
workflow_event: temporalio.api.common.v1.Link.WorkflowEvent,
504-
) -> nexusrpc.Link:
505-
scheme = "temporal"
506-
namespace = urllib.parse.quote(workflow_event.namespace)
507-
workflow_id = urllib.parse.quote(workflow_event.workflow_id)
508-
run_id = urllib.parse.quote(workflow_event.run_id)
509-
path = f"/namespaces/{namespace}/workflows/{workflow_id}/{run_id}/history"
510-
query_params = urllib.parse.urlencode(
511-
{
512-
"eventType": temporalio.api.enums.v1.EventType.Name(
513-
workflow_event.event_ref.event_type
514-
),
515-
"referenceType": "EventReference",
516-
}
517-
)
518-
return nexusrpc.Link(
519-
url=urllib.parse.urlunparse((scheme, "", path, "", query_params, "")),
520-
type=workflow_event.DESCRIPTOR.full_name,
521-
)
522-
523-
524-
_LINK_URL_PATH_REGEX = re.compile(
525-
r"^/namespaces/(?P<namespace>[^/]+)/workflows/(?P<workflow_id>[^/]+)/(?P<run_id>[^/]+)/history$"
526-
)
527-
528-
529-
def _nexus_link_to_workflow_event(
530-
link: nexusrpc.Link,
531-
) -> Optional[temporalio.api.common.v1.Link.WorkflowEvent]:
532-
url = urllib.parse.urlparse(link.url)
533-
match = _LINK_URL_PATH_REGEX.match(url.path)
534-
if not match:
535-
logger.warning(
536-
f"Invalid Nexus link: {link}. Expected path to match {_LINK_URL_PATH_REGEX.pattern}"
537-
)
538-
return None
539-
try:
540-
query_params = urllib.parse.parse_qs(url.query)
541-
[reference_type] = query_params.get("referenceType", [])
542-
if reference_type != "EventReference":
543-
raise ValueError(
544-
f"Expected Nexus link URL query parameter referenceType to be EventReference but got: {reference_type}"
545-
)
546-
[event_type_name] = query_params.get("eventType", [])
547-
event_ref = temporalio.api.common.v1.Link.WorkflowEvent.EventReference(
548-
# TODO(nexus-prerelease): confirm that it is correct not to use event_id.
549-
# Should the proto say explicitly that it's optional or how it behaves when it's missing?
550-
event_type=temporalio.api.enums.v1.EventType.Value(event_type_name)
551-
)
552-
except ValueError as err:
553-
logger.warning(
554-
f"Failed to parse event type from Nexus link URL query parameters: {link} ({err})"
555-
)
556-
event_ref = None
557-
558-
groups = match.groupdict()
559-
return temporalio.api.common.v1.Link.WorkflowEvent(
560-
namespace=urllib.parse.unquote(groups["namespace"]),
561-
workflow_id=urllib.parse.unquote(groups["workflow_id"]),
562-
run_id=urllib.parse.unquote(groups["run_id"]),
563-
event_ref=event_ref,
564-
)
565-
566-
567474
class LoggerAdapter(logging.LoggerAdapter):
568475
"""Logger adapter that adds Nexus operation context information."""
569476

tests/helpers/nexus.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import dataclasses
22
from dataclasses import dataclass
33
from typing import Any, Mapping, Optional
4+
from urllib.parse import urlparse
45

56
import temporalio.api.failure.v1
67
import temporalio.api.nexus.v1
78
import temporalio.api.operatorservice.v1
89
import temporalio.workflow
910
from temporalio.client import Client
1011
from temporalio.converter import FailureConverter, PayloadConverter
12+
from temporalio.testing import WorkflowEnvironment
1113

1214
with temporalio.workflow.unsafe.imports_passed_through():
1315
import httpx
@@ -58,7 +60,7 @@ async def start_operation(
5860
# TODO(nexus-preview): Support callback URL as query param
5961
async with httpx.AsyncClient() as http_client:
6062
return await http_client.post(
61-
f"{self.server_address}/nexus/endpoints/{self.endpoint}/services/{self.service}/{operation}",
63+
f"http://{self.server_address}/nexus/endpoints/{self.endpoint}/services/{self.service}/{operation}",
6264
json=body,
6365
headers=headers,
6466
)
@@ -70,11 +72,20 @@ async def cancel_operation(
7072
) -> httpx.Response:
7173
async with httpx.AsyncClient() as http_client:
7274
return await http_client.post(
73-
f"{self.server_address}/nexus/endpoints/{self.endpoint}/services/{self.service}/{operation}/cancel",
75+
f"http://{self.server_address}/nexus/endpoints/{self.endpoint}/services/{self.service}/{operation}/cancel",
7476
# Token can also be sent as "Nexus-Operation-Token" header
7577
params={"token": token},
7678
)
7779

80+
@staticmethod
def default_server_address(env: WorkflowEnvironment) -> str:
    """Return the ``host:port`` address used for direct Nexus HTTP requests.

    The host is taken from the target host of the environment's client and the
    port from the environment's ``_http_port`` attribute, falling back to
    ``127.0.0.1`` and ``7243`` respectively when they cannot be determined.

    Args:
        env: The workflow environment under test.

    Returns:
        The address as ``host:port`` (without a URL scheme).
    """
    # TODO(nexus-preview): nexus tests are making http requests directly but this is
    # not officially supported.
    target_host = env.client.service_client.config.target_host
    # urlparse only recognizes a network location when it is introduced by
    # "//"; a bare "host:port" would otherwise be split into scheme/path and
    # yield no hostname, so the real host would never be used.
    if "://" not in target_host:
        target_host = f"//{target_host}"
    host = urlparse(target_host).hostname or "127.0.0.1"
    if ":" in host:
        # IPv6 literal: restore the brackets that urlparse strips.
        host = f"[{host}]"
    http_port = getattr(env, "_http_port", 7243)
    return f"{host}:{http_port}"
88+
7889

7990
def dataclass_as_dict(dataclass: Any) -> dict[str, Any]:
8091
"""

0 commit comments

Comments
 (0)