Skip to content

Commit a154cf3

Browse files
authored
add custom json coder support (#285)
1 parent c4d32e3 commit a154cf3

File tree

8 files changed

+153
-93
lines changed

8 files changed

+153
-93
lines changed

Sources/MySQLKit/MySQLDatabase.swift renamed to Sources/MySQLKit/MySQLConfiguration.swift

Lines changed: 0 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -79,70 +79,3 @@ public struct MySQLConfiguration {
7979
self._hostname = hostname
8080
}
8181
}
82-
83-
public struct MySQLConnectionSource: ConnectionPoolSource {
84-
public let configuration: MySQLConfiguration
85-
86-
public init(configuration: MySQLConfiguration) {
87-
self.configuration = configuration
88-
}
89-
90-
public func makeConnection(logger: Logger, on eventLoop: EventLoop) -> EventLoopFuture<MySQLConnection> {
91-
let address: SocketAddress
92-
do {
93-
address = try self.configuration.address()
94-
} catch {
95-
return eventLoop.makeFailedFuture(error)
96-
}
97-
return MySQLConnection.connect(
98-
to: address,
99-
username: self.configuration.username,
100-
database: self.configuration.database ?? self.configuration.username,
101-
password: self.configuration.password,
102-
tlsConfiguration: self.configuration.tlsConfiguration,
103-
logger: logger,
104-
on: eventLoop
105-
)
106-
}
107-
}
108-
109-
extension MySQLConnection: ConnectionPoolItem { }
110-
111-
struct MissingColumn: Error {
112-
let column: String
113-
}
114-
115-
extension MySQLRow: SQLRow {
116-
public var allColumns: [String] {
117-
self.columnDefinitions.map { $0.name }
118-
}
119-
120-
public func contains(column: String) -> Bool {
121-
self.columnDefinitions.contains { $0.name == column }
122-
}
123-
124-
public func decodeNil(column: String) throws -> Bool {
125-
guard let data = self.column(column) else {
126-
return true
127-
}
128-
return data.buffer == nil
129-
}
130-
131-
public func decode<D>(column: String, as type: D.Type) throws -> D where D : Decodable {
132-
guard let data = self.column(column) else {
133-
throw MissingColumn(column: column)
134-
}
135-
return try MySQLDataDecoder().decode(D.self, from: data)
136-
}
137-
}
138-
139-
public struct SQLRaw: SQLExpression {
140-
public var string: String
141-
public init(_ string: String) {
142-
self.string = string
143-
}
144-
145-
public func serialize(to serializer: inout SQLSerializer) {
146-
serializer.write(self.string)
147-
}
148-
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
public struct MySQLConnectionSource: ConnectionPoolSource {
2+
public let configuration: MySQLConfiguration
3+
4+
public init(configuration: MySQLConfiguration) {
5+
self.configuration = configuration
6+
}
7+
8+
public func makeConnection(logger: Logger, on eventLoop: EventLoop) -> EventLoopFuture<MySQLConnection> {
9+
let address: SocketAddress
10+
do {
11+
address = try self.configuration.address()
12+
} catch {
13+
return eventLoop.makeFailedFuture(error)
14+
}
15+
return MySQLConnection.connect(
16+
to: address,
17+
username: self.configuration.username,
18+
database: self.configuration.database ?? self.configuration.username,
19+
password: self.configuration.password,
20+
tlsConfiguration: self.configuration.tlsConfiguration,
21+
logger: logger,
22+
on: eventLoop
23+
)
24+
}
25+
}
26+
27+
extension MySQLConnection: ConnectionPoolItem { }

Sources/MySQLKit/MySQLDataDecoder.swift

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@ extension MySQLData {
2525
}
2626

2727
public struct MySQLDataDecoder {
28-
public init() {}
28+
let json: JSONDecoder
29+
30+
public init(json: JSONDecoder = .init()) {
31+
self.json = json
32+
}
2933

3034
public func decode<T>(_ type: T.Type, from data: MySQLData) throws -> T
3135
where T: Decodable
@@ -40,7 +44,7 @@ public struct MySQLDataDecoder {
4044
}
4145
return value as! T
4246
} else {
43-
return try T.init(from: _Decoder(data: data))
47+
return try T.init(from: _Decoder(data: data, json: self.json))
4448
}
4549
}
4650

@@ -52,22 +56,25 @@ public struct MySQLDataDecoder {
5256
var userInfo: [CodingUserInfoKey : Any] {
5357
return [:]
5458
}
55-
59+
5660
let data: MySQLData
57-
init(data: MySQLData) {
61+
let json: JSONDecoder
62+
63+
init(data: MySQLData, json: JSONDecoder) {
5864
self.data = data
65+
self.json = json
5966
}
6067

6168
func unkeyedContainer() throws -> UnkeyedDecodingContainer {
62-
try JSONDecoder()
69+
try self.json
6370
.decode(DecoderUnwrapper.self, from: self.data.data!)
6471
.decoder.unkeyedContainer()
6572
}
6673

6774
func container<Key>(keyedBy type: Key.Type) throws -> KeyedDecodingContainer<Key>
6875
where Key : CodingKey
6976
{
70-
try JSONDecoder()
77+
try self.json
7178
.decode(DecoderUnwrapper.self, from: self.data.data!)
7279
.decoder.container(keyedBy: Key.self)
7380
}

Sources/MySQLKit/MySQLDataEncoder.swift

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import Foundation
22

33
public struct MySQLDataEncoder {
4-
public init() { }
4+
let json: JSONEncoder
5+
6+
public init(json: JSONEncoder = .init()) {
7+
self.json = json
8+
}
59

610
public func encode(_ value: Encodable) throws -> MySQLData {
711
if let custom = value as? MySQLDataConvertible, let data = custom.mysqlData {
@@ -13,7 +17,7 @@ public struct MySQLDataEncoder {
1317
return data
1418
} else {
1519
var buffer = ByteBufferAllocator().buffer(capacity: 0)
16-
try buffer.writeBytes(JSONEncoder().encode(_Wrapper(value)))
20+
try buffer.writeBytes(self.json.encode(_Wrapper(value)))
1721
return MySQLData(
1822
type: .string,
1923
format: .text,

Sources/MySQLKit/MySQLDatabase+SQL.swift

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
extension MySQLDatabase {
2-
public func sql() -> SQLDatabase {
3-
_MySQLSQLDatabase(database: self)
2+
public func sql(
3+
encoder: MySQLDataEncoder = .init(),
4+
decoder: MySQLDataDecoder = .init()
5+
) -> SQLDatabase {
6+
_MySQLSQLDatabase(database: self, encoder: encoder, decoder: decoder)
47
}
58
}
69

710

811
private struct _MySQLSQLDatabase {
912
let database: MySQLDatabase
13+
let encoder: MySQLDataEncoder
14+
let decoder: MySQLDataDecoder
1015
}
1116

1217
extension _MySQLSQLDatabase: SQLDatabase {
@@ -26,9 +31,9 @@ extension _MySQLSQLDatabase: SQLDatabase {
2631
let (sql, binds) = self.serialize(query)
2732
do {
2833
return try self.database.query(sql, binds.map { encodable in
29-
return try MySQLDataEncoder().encode(encodable)
34+
return try self.encoder.encode(encodable)
3035
}, onRow: { row in
31-
onRow(row)
36+
onRow(row.sql(decoder: self.decoder))
3237
})
3338
} catch {
3439
return self.eventLoop.makeFailedFuture(error)

Sources/MySQLKit/MySQLRow+SQL.swift

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
extension MySQLRow {
2+
public func sql(decoder: MySQLDataDecoder = .init()) -> SQLRow {
3+
_MySQLSQLRow(row: self, decoder: decoder)
4+
}
5+
}
6+
7+
struct MissingColumn: Error {
8+
let column: String
9+
}
10+
11+
private struct _MySQLSQLRow: SQLRow {
12+
let row: MySQLRow
13+
let decoder: MySQLDataDecoder
14+
15+
var allColumns: [String] {
16+
self.row.columnDefinitions.map { $0.name }
17+
}
18+
19+
func contains(column: String) -> Bool {
20+
self.row.columnDefinitions.contains { $0.name == column }
21+
}
22+
23+
func decodeNil(column: String) throws -> Bool {
24+
guard let data = self.row.column(column) else {
25+
return true
26+
}
27+
return data.buffer == nil
28+
}
29+
30+
func decode<D>(column: String, as type: D.Type) throws -> D where D : Decodable {
31+
guard let data = self.row.column(column) else {
32+
throw MissingColumn(column: column)
33+
}
34+
return try self.decoder.decode(D.self, from: data)
35+
}
36+
}

Tests/LinuxMain.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
11
#error("Please test with `swift test --enable-test-discovery`")
2-

Tests/MySQLKitTests/MySQLKitTests.swift

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,37 +18,86 @@ class MySQLKitTests: XCTestCase {
1818
let name: String?
1919
}
2020

21-
let rows = try self.db.raw("SELECT 1 as `id`, null as `name`")
21+
let rows = try self.sql.raw("SELECT 1 as `id`, null as `name`")
2222
.all(decoding: Person.self).wait()
2323
XCTAssertEqual(rows[0].id, 1)
2424
XCTAssertEqual(rows[0].name, nil)
2525
}
2626

27-
var db: SQLDatabase {
28-
self.connection.sql()
27+
func testCustomJSONCoder() throws {
28+
let encoder = JSONEncoder()
29+
encoder.dateEncodingStrategy = .secondsSince1970
30+
let decoder = JSONDecoder()
31+
decoder.dateDecodingStrategy = .secondsSince1970
32+
let db = self.mysql.sql(encoder: .init(json: encoder), decoder: .init(json: decoder))
33+
34+
struct Foo: Codable, Equatable {
35+
var bar: Bar
36+
}
37+
struct Bar: Codable, Equatable {
38+
var baz: Date
39+
}
40+
41+
try db.create(table: "foo")
42+
.column("bar", type: .custom(SQLRaw("JSON")))
43+
.run().wait()
44+
defer { try! db.drop(table: "foo").ifExists().run().wait() }
45+
46+
let foo = Foo(bar: .init(baz: .init(timeIntervalSince1970: 1337)))
47+
try db.insert(into: "foo").model(foo).run().wait()
48+
49+
let rows = try db.select().columns("*").from("foo").all(decoding: Foo.self).wait()
50+
XCTAssertEqual(rows, [foo])
51+
}
52+
53+
var sql: SQLDatabase {
54+
self.mysql.sql()
2955
}
56+
57+
var mysql: MySQLDatabase {
58+
self.pool.pool(for: self.eventLoopGroup.next())
59+
.database(logger: .init(label: "codes.vapor.mysql"))
60+
}
61+
3062
var benchmark: SQLBenchmarker {
31-
.init(on: self.db)
63+
.init(on: self.sql)
3264
}
3365

3466
var eventLoopGroup: EventLoopGroup!
35-
var connection: MySQLConnection!
67+
var pool: EventLoopGroupConnectionPool<MySQLConnectionSource>!
3668

3769
override func setUpWithError() throws {
3870
try super.setUpWithError()
3971
XCTAssertTrue(isLoggingConfigured)
4072
self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2)
41-
self.connection = try MySQLConnection.test(
42-
on: self.eventLoopGroup.next()
43-
).wait()
44-
_ = try self.connection.simpleQuery("DROP DATABASE vapor_database").wait()
45-
_ = try self.connection.simpleQuery("CREATE DATABASE vapor_database").wait()
46-
_ = try self.connection.simpleQuery("USE vapor_database").wait()
73+
self.pool = .init(
74+
source: .init(configuration: .init(
75+
hostname: env("MYSQL_HOSTNAME") ?? "localhost",
76+
port: 3306,
77+
username: "vapor_username",
78+
password: "vapor_password",
79+
database: "vapor_database",
80+
tlsConfiguration: .forClient(certificateVerification: .none)
81+
)),
82+
maxConnectionsPerEventLoop: 2,
83+
requestTimeout: .seconds(30),
84+
logger: .init(label: "codes.vapor.mysql"),
85+
on: self.eventLoopGroup
86+
)
87+
88+
// Reset database.
89+
_ = try self.mysql.withConnection { conn in
90+
return conn.simpleQuery("DROP DATABASE vapor_database").flatMap { _ in
91+
conn.simpleQuery("CREATE DATABASE vapor_database")
92+
}.flatMap { _ in
93+
conn.simpleQuery("USE vapor_database")
94+
}
95+
}.wait()
4796
}
4897

4998
override func tearDownWithError() throws {
50-
try self.connection?.close().wait()
51-
self.connection = nil
99+
try self.pool.syncShutdownGracefully()
100+
self.pool = nil
52101
try self.eventLoopGroup.syncShutdownGracefully()
53102
self.eventLoopGroup = nil
54103
try super.tearDownWithError()

0 commit comments

Comments
 (0)