Skip to content

Commit 66db9c3

Browse files
committed
handle partition_by option
1 parent fee4257 commit 66db9c3

File tree

5 files changed

+47
-14
lines changed

5 files changed

+47
-14
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ KinesisClient.Stream.start_link(opts)
6060
`MyShardConsumer` needs to implement the `Broadway` behaviour. You will want to
6161
start the `KinesisClient.Stream` in your application's supervision tree.
6262

63+
## partition_by option
64+
65+
If you want to include the partition_by option to the Broadway pipeline
66+
then you need to implement a partition_by/1 function in the consumer `MyShardConsumer`.
67+
6368

6469
## Things to keep in mind...
6570

lib/kinesis_client/stream/shard/pipeline.ex

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,18 @@ defmodule KinesisClient.Stream.Shard.Pipeline do
4646
|> Keyword.get(:pipeline_context, %{})
4747
|> Map.put(:shard_consumer, opts[:shard_consumer])
4848

49-
pipeline_opts = [
50-
name: register_name(__MODULE__, opts[:app_name], opts[:stream_name], [opts[:shard_id]]),
51-
producer: [
52-
module: {Producer, producer_opts},
53-
concurrency: 1
54-
],
55-
context: pipeline_context,
56-
processors: processor_opts,
57-
batchers: batcher_opts
58-
]
59-
60-
pipeline_opts = optional_kw(pipeline_opts, :partition_by, Keyword.get(opts, :partition_by))
49+
pipeline_opts =
50+
[
51+
name: register_name(__MODULE__, opts[:app_name], opts[:stream_name], [opts[:shard_id]]),
52+
producer: [
53+
module: {Producer, producer_opts},
54+
concurrency: 1
55+
],
56+
context: pipeline_context,
57+
processors: processor_opts,
58+
batchers: batcher_opts
59+
]
60+
|> partition_by(opts)
6161

6262
Broadway.start_link(__MODULE__, pipeline_opts)
6363
end
@@ -108,6 +108,17 @@ defmodule KinesisClient.Stream.Shard.Pipeline do
108108
end
109109
end
110110

111+
@impl Broadway
112+
def prepare_messages(messages, ctx) do
113+
module = Map.get(ctx, :shard_consumer)
114+
115+
if function_exported?(module, :prepare_messages, 2) do
116+
module.prepare_messages(messages, ctx)
117+
else
118+
messages
119+
end
120+
end
121+
111122
@impl Broadway
112123
def handle_message(processor, msg, ctx) do
113124
module = Map.get(ctx, :shard_consumer)
@@ -127,4 +138,16 @@ defmodule KinesisClient.Stream.Shard.Pipeline do
127138

128139
module.handle_failed(messages, context)
129140
end
141+
142+
defp partition_by(pipeline_opts, opts) do
143+
shard_consumer =
144+
opts[:shard_consumer]
145+
|> Code.ensure_loaded!()
146+
147+
if function_exported?(shard_consumer, :partition_by, 1) do
148+
Keyword.put(pipeline_opts, :partition_by, &shard_consumer.partition_by/1)
149+
else
150+
pipeline_opts
151+
end
152+
end
130153
end

test/kinesis_client/stream/coordinator_test.exs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ defmodule KinesisClient.Stream.CoordinatorTest do
249249
notify_pid: self(),
250250
kinesis_opts: [adapter: KinesisClient.KinesisMock],
251251
shard_args: [
252+
shard_consumer: __MODULE__,
252253
app_name: app_name,
253254
coordinator_name: coordinator_name,
254255
lease_owner: worker_ref(),

test/kinesis_client/stream/shard/lease_test.exs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ defmodule KinesisClient.Stream.Shard.LeaseTest do
2727

2828
assert_receive {:initialized, lease_state}, 1_000
2929

30-
inspect(lease_state, label: "lease_state")
3130
assert lease_state.lease_holder == true
3231
assert lease_state.lease_count == 1
3332

test/kinesis_client/stream/shard/pipeline_test.exs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,12 @@ defmodule KinesisClient.Stream.Shard.PipelineTest do
7171
stream_name = "pipeline-test-stream"
7272
shard_id = "shard-1"
7373

74-
opts = [app_name: app_name, stream_name: stream_name, shard_id: shard_id]
74+
opts = [
75+
app_name: app_name,
76+
stream_name: stream_name,
77+
shard_id: shard_id,
78+
shard_consumer: __MODULE__
79+
]
7580

7681
{:ok, pid} = start_supervised({Pipeline, opts})
7782

0 commit comments

Comments
 (0)