
Commit c8d7d8f

test: zero-copy transformation from struct to bytes (#203)
* test: zero-copy transformation from struct to bytes
* test: more tests
1 parent 3bcd692 commit c8d7d8f
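
The new `transmute_data_frames` test round-trips a struct holding Arrow Flight header/body buffers through raw bytes, compression, and serialization without copying the buffers themselves. A minimal, standalone sketch of that zero-copy pattern is shown below (the `Frame` struct and its field values are illustrative only, and the test's `#[repr(packed)]` attribute is omitted here):

    // Standalone sketch of the zero-copy struct <-> bytes pattern used by the test.
    use std::{mem, slice};

    struct Frame {
        header: Vec<u8>,
        body: Vec<u8>,
    }

    fn main() {
        let frame = Frame {
            header: vec![1, 2, 3],
            body: vec![4, 5, 6, 7],
        };

        // Struct -> bytes: view the struct's own memory as a raw byte slice.
        let ptr = &frame as *const Frame as *const u8;
        let bytes: &[u8] = unsafe { slice::from_raw_parts(ptr, mem::size_of::<Frame>()) };

        // Bytes -> struct: reinterpret the byte view as `Frame` again.
        let (head, body, _tail) = unsafe { bytes.align_to::<Frame>() };
        assert!(head.is_empty(), "data was not aligned");
        let view = &body[0];

        // The `Vec`s behind `view` point at the same heap allocations as `frame`,
        // so the payload bytes were never copied.
        assert_eq!(frame.header, view.header);
        assert_eq!(frame.body, view.body);
    }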

File tree

1 file changed: +118 -44 lines changed


src/runtime/src/payload.rs

Lines changed: 118 additions & 44 deletions
@@ -16,7 +16,7 @@
 //! services.

 use crate::encoding::Encoding;
-use arrow::datatypes::Schema;
+use arrow::datatypes::{Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
 use arrow_flight::utils::{flight_data_from_arrow_batch, flight_data_to_arrow_batch};
 use arrow_flight::FlightData;
@@ -110,7 +110,7 @@ pub struct Payload {
     #[serde(with = "serde_bytes")]
     pub data: Vec<u8>,
     /// The subplan's schema.
-    schema: Schema,
+    schema: SchemaRef,
     /// The query's uuid.
     pub uuid: Uuid,
     /// Compress `DataFrame` to guarantee the total size
@@ -122,7 +122,7 @@ impl Default for Payload {
     fn default() -> Payload {
         Self {
             data: vec![],
-            schema: Schema::empty(),
+            schema: Arc::new(Schema::empty()),
             uuid: Uuid::default(),
             encoding: Encoding::default(),
         }
@@ -147,7 +147,7 @@ impl Payload {
                 app_metadata: vec![],
                 flight_descriptor: None,
             },
-            Arc::new(payload.schema.clone()),
+            payload.schema.clone(),
             &[],
         )
         .unwrap()
@@ -171,12 +171,7 @@ impl Payload {
             })
             .collect();

-        marshal2value(
-            &data_frames,
-            (*batches[0].schema()).clone(),
-            uuid,
-            Encoding::default(),
-        )
+        marshal2value(&data_frames, batches[0].schema(), uuid, Encoding::default())
     }

     /// Convert record batch to payload for network transmission.
@@ -193,19 +188,14 @@ impl Payload {
             })
             .collect();

-        marshal2bytes(
-            &data_frames,
-            (*batches[0].schema()).clone(),
-            uuid,
-            Encoding::default(),
-        )
+        marshal2bytes(&data_frames, batches[0].schema(), uuid, Encoding::default())
     }
 }

 /// Serialize `Payload` in cloud functions.
 pub fn marshal2value(
     data: &Vec<DataFrame>,
-    schema: Schema,
+    schema: SchemaRef,
     uuid: Uuid,
     encoding: Encoding,
 ) -> Value {
@@ -234,7 +224,7 @@ pub fn marshal2value(
 /// Serialize `Payload` in cloud functions.
 pub fn marshal2bytes(
     data: &Vec<DataFrame>,
-    schema: Schema,
+    schema: SchemaRef,
     uuid: Uuid,
     encoding: Encoding,
 ) -> bytes::Bytes {
@@ -282,6 +272,8 @@ mod tests {
     use arrow::csv;
     use arrow::datatypes::{DataType, Field, Schema};
     use arrow::json;
+    use std::mem;
+    use std::slice;
     use std::sync::Arc;
     use std::time::Instant;

@@ -334,8 +326,7 @@ mod tests {
         assert_eq!(1856, flight_data_size);
     }

-    #[test]
-    fn flight_data_compression_ratio_2() {
+    fn init_batches() -> RecordBatch {
         let schema = Arc::new(Schema::new(vec![
             Field::new("tripduration", DataType::Utf8, false),
             Field::new("starttime", DataType::Utf8, false),
@@ -357,7 +348,12 @@ mod tests {
         let records: &[u8] =
             include_str!("../../test/data/JC-202011-citibike-tripdata.csv").as_bytes();
         let mut reader = csv::Reader::new(records, schema, true, None, 21275, None, None);
-        let batch = reader.next().unwrap().unwrap();
+        reader.next().unwrap().unwrap()
+    }
+
+    #[test]
+    fn flight_data_compression_ratio_2() {
+        let batch = init_batches();

         // Arrow RecordBatch (in-memory)
         let size: usize = batch
@@ -436,29 +432,7 @@ mod tests {

     #[tokio::test]
     async fn serde_payload() -> Result<()> {
-        let schema = Arc::new(Schema::new(vec![
-            Field::new("tripduration", DataType::Utf8, false),
-            Field::new("starttime", DataType::Utf8, false),
-            Field::new("stoptime", DataType::Utf8, false),
-            Field::new("start station id", DataType::Int32, false),
-            Field::new("start station name", DataType::Utf8, false),
-            Field::new("start station latitude", DataType::Float64, false),
-            Field::new("start station longitude", DataType::Float64, false),
-            Field::new("end station id", DataType::Int32, false),
-            Field::new("end station name", DataType::Utf8, false),
-            Field::new("end station latitude", DataType::Float64, false),
-            Field::new("end station longitude", DataType::Float64, false),
-            Field::new("bikeid", DataType::Int32, false),
-            Field::new("usertype", DataType::Utf8, false),
-            Field::new("birth year", DataType::Int32, false),
-            Field::new("gender", DataType::Int8, false),
-        ]));
-
-        let records: &[u8] =
-            include_str!("../../test/data/JC-202011-citibike-tripdata.csv").as_bytes();
-        let mut reader = csv::Reader::new(records, schema, true, None, 21275, None, None);
-
-        let batches = vec![reader.next().unwrap().unwrap()];
+        let batches = vec![init_batches()];
         let mut uuid_builder =
             UuidBuilder::new("SX72HzqFz1Qij4bP-00-2021-01-28T19:27:50.298504836", 10);
         let uuid = uuid_builder.next();
@@ -481,4 +455,104 @@ mod tests {

         Ok(())
     }
+
+    #[tokio::test]
+    async fn transmute_data_frames() -> Result<()> {
+        #[repr(packed)]
+        pub struct DataFrameStruct {
+            /// Arrow Flight Data's header.
+            header: Vec<u8>,
+            /// Arrow Flight Data's body.
+            body: Vec<u8>,
+        }
+
+        let batch = init_batches();
+        let schema = batch.schema();
+        let batches = vec![batch.clone(), batch.clone(), batch];
+        let mut uuid_builder =
+            UuidBuilder::new("SX72HzqFz1Qij4bP-00-2021-01-28T19:27:50.298504836", 10);
+        let uuid = uuid_builder.next();
+
+        let options = arrow::ipc::writer::IpcWriteOptions::default();
+        let data_frames = (0..batches.len())
+            .map(|i| {
+                let (_, flight_data) = flight_data_from_arrow_batch(&batches[i], &options);
+                DataFrameStruct {
+                    header: flight_data.data_header,
+                    body: flight_data.data_body,
+                }
+            })
+            .collect::<Vec<DataFrameStruct>>();
+        unsafe {
+            println!(
+                "transmute data - raw data: {}",
+                data_frames[0].header.len() + data_frames[0].body.len(),
+            );
+        }
+
+        let p: *const DataFrameStruct = &data_frames[0];
+        let p: *const u8 = p as *const u8;
+        let d: &[u8] = unsafe { slice::from_raw_parts(p, mem::size_of::<DataFrameStruct>()) };
+
+        let (head, body, _tail) = unsafe { d.align_to::<DataFrameStruct>() };
+        assert!(head.is_empty(), "Data was not aligned");
+        let my_struct = &body[0];
+
+        unsafe {
+            assert_eq!(data_frames[0].header.len(), (*my_struct).header.len());
+            assert_eq!(data_frames[0].header, (*my_struct).header);
+            assert_eq!(data_frames[0].body.len(), (*my_struct).body.len());
+            assert_eq!(data_frames[0].body, (*my_struct).body);
+        }
+
+        let encoding = Encoding::Zstd;
+        // compress
+        let now = Instant::now();
+        let event: bytes::Bytes = serde_json::to_vec(&Payload {
+            data: encoding.compress(&d),
+            uuid: uuid.clone(),
+            encoding: encoding.clone(),
+            schema,
+        })
+        .unwrap()
+        .into();
+        println!(
+            "transmute data - compression time: {} us",
+            now.elapsed().as_micros()
+        );
+        println!(
+            "transmute data - compressed data: {}, type: {:?}",
+            event.len(),
+            encoding
+        );

+        // decompress
+        let now = Instant::now();
+        let payload: Payload = serde_json::from_slice(&event).unwrap();
+        let de_uuid = payload.uuid.clone();
+        let encoded = payload.encoding.decompress(&payload.data);
+
+        let (head, body, _tail) = unsafe { encoded.align_to::<DataFrameStruct>() };
+        println!(
+            "transmute data - decompression time: {} us",
+            now.elapsed().as_micros()
+        );
+
+        let de_struct = &body[0];
+        assert!(head.is_empty(), "Data was not aligned");
+
+        unsafe {
+            assert_eq!(data_frames[0].header.len(), (*de_struct).header.len());
+            assert_eq!(data_frames[0].header, (*de_struct).header);
+            assert_eq!(data_frames[0].body.len(), (*de_struct).body.len());
+            assert_eq!(data_frames[0].body, (*de_struct).body);
+            assert_eq!(uuid, de_uuid);
+            println!(
+                "transmute data - decompress raw data: {}",
+                (*de_struct).header.len() + (*de_struct).body.len(),
+            );
+        }
+
+        Ok(())
+    }
 }
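
Besides the new tests, the diff changes `Payload::schema` from an owned `Schema` to `SchemaRef` (Arrow's `Arc<Schema>` alias), so `payload.schema.clone()` now shares one schema allocation instead of deep-copying the field list. A rough sketch of what that buys, assuming only the `arrow` crate used above:

    // Illustrative sketch: cloning a SchemaRef is a reference-count bump, not a deep copy.
    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use std::sync::Arc;

    fn main() {
        let schema: SchemaRef = Arc::new(Schema::new(vec![
            Field::new("tripduration", DataType::Utf8, false),
            Field::new("bikeid", DataType::Int32, false),
        ]));

        // Both handles point at the same `Schema` allocation.
        let shared = schema.clone();
        assert!(Arc::ptr_eq(&schema, &shared));
    }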
