Skip to content

Commit cda4732

Browse files
authored
Merge pull request #169 from vapor/dbkit-gm
dbkit 1.0.0 gm
2 parents 02edf2b + e93c760 commit cda4732

13 files changed

+178
-47
lines changed

Package.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ let package = Package(
1414
.package(url: "https://github.com/vapor/crypto.git", from: "3.0.0-rc.2"),
1515

1616
// 🗄 Core services for creating database integrations.
17-
.package(url: "https://github.com/vapor/database-kit.git", from: "1.0.0-rc.2"),
17+
.package(url: "https://github.com/vapor/database-kit.git", from: "1.0.0"),
1818

1919
// 📦 Dependency injection / inversion of control framework.
2020
.package(url: "https://github.com/vapor/service.git", from: "1.0.0-rc.2"),

Sources/MySQL/Connection/MySQLConnection+Authenticate.swift

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,13 @@ extension MySQLConnection {
2222
guard let handshake = handshake else {
2323
throw MySQLError(identifier: "handshake", reason: "Handshake required for auth response.", source: .capture())
2424
}
25-
let authPlugin = handshake.authPluginName ?? "none"
25+
let authPlugin = handshake.authPluginName
2626
let authResponse: Data
2727
switch authPlugin {
28-
case "mysql_native_password":
28+
case .some("mysql_native_password"), .none:
29+
guard handshake.capabilities.get(CLIENT_SECURE_CONNECTION) else {
30+
throw MySQLError(identifier: "authproto", reason: "Pre-4.1 auth protocol is not supported or safe.", source: .capture())
31+
}
2932
guard let password = password else {
3033
throw MySQLError(identifier: "password", reason: "Password required for auth plugin.", source: .capture())
3134
}
@@ -40,7 +43,7 @@ extension MySQLConnection {
4043
hash[i] = hash[i] ^ passwordHash[i]
4144
}
4245
authResponse = hash
43-
default: throw MySQLError(identifier: "authPlugin", reason: "Unsupported auth plugin: \(authPlugin)", source: .capture())
46+
default: throw MySQLError(identifier: "authPlugin", reason: "Unsupported auth plugin: \(authPlugin ?? "<none>")", source: .capture())
4447
}
4548
let response = MySQLHandshakeResponse41(
4649
capabilities: [
@@ -55,7 +58,7 @@ extension MySQLConnection {
5558
username: username,
5659
authResponse: authResponse,
5760
database: database,
58-
authPluginName: authPlugin
61+
authPluginName: authPlugin ?? "<none>"
5962
)
6063
return self.send([.handshakeResponse41(response)]) { message in
6164
switch message {

Sources/MySQL/Connection/MySQLConnection.swift

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,17 @@ import Service
77

88
/// A MySQL frontend client.
99
public final class MySQLConnection: BasicWorker, DatabaseConnection {
10-
/// See `Worker.eventLoop`
10+
/// See `Worker`.
1111
public var eventLoop: EventLoop {
1212
return channel.eventLoop
1313
}
1414

15+
/// See `DatabaseConnection`.
16+
public var isClosed: Bool
17+
18+
/// See `Extendable`
19+
public var extend: Extend
20+
1521
/// Handles enqueued redis commands and responses.
1622
private let queue: QueueHandler<MySQLPacket, MySQLPacket>
1723

@@ -24,39 +30,68 @@ public final class MySQLConnection: BasicWorker, DatabaseConnection {
2430
/// The current query running, if one exists.
2531
private var pipeline: Future<Void>
2632

27-
/// See `Extendable.extend`
28-
public var extend: Extend
33+
/// Currently running `send(...)`.
34+
private var currentSend: Promise<Void>?
2935

3036
/// Creates a new MySQL client with the provided MySQL packet queue and channel.
3137
init(queue: QueueHandler<MySQLPacket, MySQLPacket>, channel: Channel) {
3238
self.queue = queue
3339
self.channel = channel
3440
self.pipeline = Future.map(on: channel.eventLoop) { }
3541
self.extend = [:]
42+
self.isClosed = false
43+
44+
// when the channel closes, set isClosed to true and fail any
45+
// currently running calls to `send(...)`.
46+
channel.closeFuture.always {
47+
self.isClosed = true
48+
if let current = self.currentSend {
49+
current.fail(error: closeError)
50+
}
51+
}
3652
}
3753

3854
/// Sends `MySQLPacket` to the server.
3955
internal func send(_ messages: [MySQLPacket], onResponse: @escaping (MySQLPacket) throws -> Bool) -> Future<Void> {
40-
return queue.enqueue(messages) { message in
56+
// if currentSend is not nil, previous send has not completed
57+
assert(currentSend == nil, "Attempting to call `send(...)` again before previous invocation has completed.")
58+
59+
// if the connection is closed, fail immidiately
60+
guard !isClosed else {
61+
return eventLoop.newFailedFuture(error: closeError)
62+
}
63+
64+
// create a new promise and store it
65+
let promise = eventLoop.newPromise(Void.self)
66+
currentSend = promise
67+
68+
// cascade this enqueue to the newly created promise
69+
queue.enqueue(messages) { message in
4170
switch message {
4271
case .err(let err): throw err.makeError(source: .capture())
4372
default: return try onResponse(message)
4473
}
45-
}
74+
}.cascade(promise: promise)
75+
76+
// when the promise completes, remove the reference to it
77+
promise.futureResult.always { self.currentSend = nil }
78+
79+
// return the promise's future result (same as `queue.enqueue`)
80+
return promise.futureResult
4681
}
4782

4883
/// Submits an async task to be pipelined.
4984
internal func operation(_ work: @escaping () -> Future<Void>) -> Future<Void> {
50-
/// perform this work when the current pipeline future is completed
85+
// perform this work when the current pipeline future is completed
5186
let new = pipeline.then(work)
5287

53-
/// append this work to the pipeline, discarding errors as the pipeline
54-
//// does not care about them
88+
// append this work to the pipeline, discarding errors as the pipeline
89+
// does not care about them
5590
pipeline = new.catchMap { err in
5691
return ()
5792
}
5893

59-
/// return the newly enqueued work's future result
94+
// return the newly enqueued work's future result
6095
return new
6196
}
6297

@@ -65,3 +100,6 @@ public final class MySQLConnection: BasicWorker, DatabaseConnection {
65100
channel.close(promise: nil)
66101
}
67102
}
103+
104+
/// Error to throw if the connection has closed.
105+
private let closeError = MySQLError(identifier: "closed", reason: "Connection is closed.", source: .capture())

Sources/MySQL/Connection/MySQLData.swift

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,16 @@ public struct MySQLData: Equatable {
5454
let storage: MySQLBinaryDataStorage?
5555

5656
if let integer = integer {
57-
if I.isSigned {
58-
storage = .integer8(numericCast(integer))
59-
} else {
60-
storage = .uinteger8(numericCast(integer))
57+
switch (I.bitWidth, I.isSigned) {
58+
case ( 8, true): storage = .integer1(numericCast(integer))
59+
case ( 8, false): storage = .uinteger1(numericCast(integer))
60+
case (16, true): storage = .integer2(numericCast(integer))
61+
case (16, false): storage = .uinteger2(numericCast(integer))
62+
case (32, true): storage = .integer4(numericCast(integer))
63+
case (32, false): storage = .uinteger4(numericCast(integer))
64+
case (64, true): storage = .integer8(numericCast(integer))
65+
case (64, false): storage = .uinteger8(numericCast(integer))
66+
default: fatalError("Unsupported bit-width: \(I.bitWidth)")
6167
}
6268
} else {
6369
storage = nil

Sources/MySQL/Database/MySQLDatabase.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ public final class MySQLDatabase: Database {
1111
self.config = config
1212
}
1313

14-
/// See `Database.makeConnection()`
15-
public func makeConnection(on worker: Worker) -> Future<MySQLConnection> {
14+
/// See `Database`
15+
public func newConnection(on worker: Worker) -> Future<MySQLConnection> {
1616
let config = self.config
1717
return Future.flatMap(on: worker) {
1818
return try MySQLConnection.connect(hostname: config.hostname, port: config.port, on: worker) { error in

Sources/MySQL/Database/MySQLLogger.swift

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,7 @@ public protocol MySQLLogger {
77
}
88

99
extension DatabaseLogger: MySQLLogger {
10-
/// See MySQLLogger.log
1110
public func log(query: String) {
12-
let log = DatabaseLog(query: query)
13-
record(log: log)
11+
record(query: query, values: [])
1412
}
1513
}
16-

Sources/MySQL/Database/MySQLProvider.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public final class MySQLProvider: Provider {
1313
try services.register(DatabaseKitProvider())
1414
services.register(MySQLDatabaseConfig.self)
1515
services.register(MySQLDatabase.self)
16-
var databases = DatabaseConfig()
16+
var databases = DatabasesConfig()
1717
databases.add(database: MySQLDatabase.self, as: .mysql)
1818
services.register(databases)
1919
}

Sources/MySQL/Pipeline/MySQLConnectionSession.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ enum MySQLStatementProtocolState {
8080

8181
case waitingExecute
8282
case rowColumns(columns: [MySQLColumnDefinition41], remaining: Int)
83+
case rowColumnsDone(columns: [MySQLColumnDefinition41])
8384
/// ProtocolBinary::ResultsetRow until eof
8485
case rows(columns: [MySQLColumnDefinition41])
8586
}

Sources/MySQL/Pipeline/MySQLPacketDecoder.swift

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ final class MySQLPacketDecoder: ByteToMessageDecoder {
1717
self.session = session
1818
}
1919

20+
func channelInactive(ctx: ChannelHandlerContext) {
21+
cumulationBuffer = nil
22+
ctx.fireChannelInactive()
23+
}
24+
2025
/// Decode from a `ByteBuffer`. This method will be called till either the input
2126
/// `ByteBuffer` has nothing to read left or `DecodingState.needMoreData` is returned.
2227
///
@@ -52,7 +57,7 @@ final class MySQLPacketDecoder: ByteToMessageDecoder {
5257
}
5358

5459
/// Decode's an OK, ERR, or EOF packet
55-
func decodeBasicPacket(ctx: ChannelHandlerContext, buffer: inout ByteBuffer, capabilities: MySQLCapabilities) throws -> DecodingState {
60+
func decodeBasicPacket(ctx: ChannelHandlerContext, buffer: inout ByteBuffer, capabilities: MySQLCapabilities, forwarding: Bool = true) throws -> DecodingState {
5661
guard let length = try buffer.checkPacketLength(source: .capture()) else {
5762
return .needMoreData
5863
}
@@ -84,7 +89,9 @@ final class MySQLPacketDecoder: ByteToMessageDecoder {
8489
}
8590

8691
session.incrementSequenceID()
87-
ctx.fireChannelRead(wrapInboundOut(packet))
92+
if forwarding {
93+
ctx.fireChannelRead(wrapInboundOut(packet))
94+
}
8895

8996
return .continue
9097
}
@@ -96,6 +103,15 @@ final class MySQLPacketDecoder: ByteToMessageDecoder {
96103
textState: MySQLTextProtocolState,
97104
capabilities: MySQLCapabilities
98105
) throws -> DecodingState {
106+
if !capabilities.get(CLIENT_DEPRECATE_EOF) {
107+
// check for error or OK packet
108+
let peek = buffer.peekInteger(as: Byte.self, skipping: 4)
109+
switch peek {
110+
case 0xFE: return try decodeBasicPacket(ctx: ctx, buffer: &buffer, capabilities: capabilities, forwarding: false)
111+
default: break
112+
}
113+
}
114+
99115
switch textState {
100116
case .waiting:
101117
// check for error or OK packet
@@ -200,7 +216,7 @@ final class MySQLPacketDecoder: ByteToMessageDecoder {
200216
}
201217

202218
if !capabilities.get(CLIENT_DEPRECATE_EOF) {
203-
return try decodeBasicPacket(ctx: ctx, buffer: &buffer, capabilities: capabilities)
219+
return try decodeBasicPacket(ctx: ctx, buffer: &buffer, capabilities: capabilities, forwarding: false)
204220
}
205221
case .columns(var remaining):
206222
guard let _ = try buffer.checkPacketLength(source: .capture()) else {
@@ -219,7 +235,7 @@ final class MySQLPacketDecoder: ByteToMessageDecoder {
219235
}
220236
case .columnsDone:
221237
if !capabilities.get(CLIENT_DEPRECATE_EOF) {
222-
return try decodeBasicPacket(ctx: ctx, buffer: &buffer, capabilities: capabilities)
238+
return try decodeBasicPacket(ctx: ctx, buffer: &buffer, capabilities: capabilities, forwarding: false)
223239
}
224240
case .waitingExecute:
225241
// check for error or OK packet
@@ -247,10 +263,17 @@ final class MySQLPacketDecoder: ByteToMessageDecoder {
247263
ctx.fireChannelRead(wrapInboundOut(.columnDefinition41(column)))
248264
remaining -= 1
249265
if remaining == 0 {
250-
session.connectionState = .statement(.rows(columns: columns))
266+
session.connectionState = .statement(.rowColumnsDone(columns: columns))
251267
} else {
252268
session.connectionState = .statement(.rowColumns(columns: columns, remaining: remaining))
253269
}
270+
case .rowColumnsDone(let columns):
271+
if !capabilities.get(CLIENT_DEPRECATE_EOF) {
272+
let result = try decodeBasicPacket(ctx: ctx, buffer: &buffer, capabilities: capabilities, forwarding: false)
273+
session.connectionState = .statement(.rows(columns: columns))
274+
return result
275+
}
276+
session.connectionState = .statement(.rows(columns: columns))
254277
case .rows(let columns):
255278
if buffer.peekInteger(as: Byte.self, skipping: 4) == 0xFE {
256279
session.connectionState = .none

Sources/MySQL/Protocol/MySQLHandshakeV10.swift

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,17 @@ struct MySQLHandshakeV10 {
7171
let reserved = try bytes.requireBytes(length: 10, source: .capture())
7272
assert(reserved == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
7373

74-
if capabilities.get(CLIENT_SECURE_CONNECTION), authPluginDataLength > 0 {
75-
let len = max(13, authPluginDataLength - 8)
76-
let authPluginDataPart2 = try bytes.requireBytes(length: numericCast(len), source: .capture())
77-
78-
self.authPluginData = Data(authPluginDataPart1 + authPluginDataPart2)
74+
if capabilities.get(CLIENT_SECURE_CONNECTION) {
75+
if capabilities.get(CLIENT_PLUGIN_AUTH) {
76+
let len = max(13, authPluginDataLength - 8)
77+
let authPluginDataPart2 = try bytes.requireBytes(length: numericCast(len), source: .capture())
78+
self.authPluginData = Data(authPluginDataPart1 + authPluginDataPart2)
79+
} else {
80+
let authPluginDataPart2 = try bytes.requireBytes(length: 12, source: .capture())
81+
self.authPluginData = Data(authPluginDataPart1 + authPluginDataPart2)
82+
let filler: Byte = try bytes.requireInteger(source: .capture())
83+
assert(filler == 0x00)
84+
}
7985
} else {
8086
self.authPluginData = Data(authPluginDataPart1)
8187
}

0 commit comments

Comments
 (0)