Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ let package = Package(
.package(url: "https://github.com/swift-server/async-http-client.git", from: "1.20.1"),
.package(url: "https://github.com/orlandos-nl/DNSClient.git", from: "2.4.1"),
.package(url: "https://github.com/Bouke/DNS.git", from: "1.2.0"),
.package(url: "https://github.com/apple/containerization.git", exact: Version(stringLiteral: scVersion)),
.package(url: "https://github.com/apple/containerization.git", branch: "main"),
],
targets: [
.executableTarget(
Expand Down
7 changes: 4 additions & 3 deletions Sources/ContainerClient/Core/ClientImage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ extension ClientImage {
})
}

public static func pull(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil) async throws -> ClientImage {
public static func pull(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil, maxConcurrentDownloads: Int = 3) async throws -> ClientImage {
let client = newXPCClient()
let request = newRequest(.imagePull)

Expand All @@ -234,6 +234,7 @@ extension ClientImage {

let insecure = try scheme.schemeFor(host: host) == .http
request.set(key: .insecureFlag, value: insecure)
request.set(key: .maxConcurrentDownloads, value: Int64(maxConcurrentDownloads))

var progressUpdateClient: ProgressUpdateClient?
if let progressUpdate {
Expand Down Expand Up @@ -293,7 +294,7 @@ extension ClientImage {
return (digests, size)
}

public static func fetch(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil) async throws -> ClientImage
public static func fetch(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil, maxConcurrentDownloads: Int = 3) async throws -> ClientImage
{
do {
let match = try await self.get(reference: reference)
Expand All @@ -307,7 +308,7 @@ extension ClientImage {
guard err.isCode(.notFound) else {
throw err
}
return try await Self.pull(reference: reference, platform: platform, scheme: scheme, progressUpdate: progressUpdate)
return try await Self.pull(reference: reference, platform: platform, scheme: scheme, progressUpdate: progressUpdate, maxConcurrentDownloads: maxConcurrentDownloads)
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions Sources/ContainerClient/Flags.swift
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,15 @@ public struct Flags {
self.disableProgressUpdates = disableProgressUpdates
}

public init(disableProgressUpdates: Bool, maxConcurrentDownloads: Int) {
self.disableProgressUpdates = disableProgressUpdates
self.maxConcurrentDownloads = maxConcurrentDownloads
}

@Flag(name: .long, help: "Disable progress bar updates")
public var disableProgressUpdates = false

@Option(name: .long, help: "Maximum number of concurrent layer downloads (default: 3)")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@adityaramani Should we use the word "layer" or "blob" here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use just "concurrent downloads"

public var maxConcurrentDownloads: Int = 3
}
}
2 changes: 1 addition & 1 deletion Sources/ContainerCommands/Image/ImagePull.swift
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ extension Application {
let taskManager = ProgressTaskCoordinator()
let fetchTask = await taskManager.startTask()
let image = try await ClientImage.pull(
reference: processedReference, platform: p, scheme: scheme, progressUpdate: ProgressTaskCoordinator.handler(for: fetchTask, from: progress.handler)
reference: processedReference, platform: p, scheme: scheme, progressUpdate: ProgressTaskCoordinator.handler(for: fetchTask, from: progress.handler), maxConcurrentDownloads: self.progressFlags.maxConcurrentDownloads
)

progress.set(description: "Unpacking image")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public enum ImagesServiceXPCKeys: String {
case ociPlatform
case insecureFlag
case garbageCollect
case maxConcurrentDownloads

/// ContentStore
case digest
Expand All @@ -54,6 +55,10 @@ extension XPCMessage {
self.set(key: key.rawValue, value: value)
}

public func set(key: ImagesServiceXPCKeys, value: Int64) {
self.set(key: key.rawValue, value: value)
}

public func set(key: ImagesServiceXPCKeys, value: Data) {
self.set(key: key.rawValue, value: value)
}
Expand All @@ -78,6 +83,10 @@ extension XPCMessage {
self.uint64(key: key.rawValue)
}

public func int64(key: ImagesServiceXPCKeys) -> Int64 {
self.int64(key: key.rawValue)
}

public func bool(key: ImagesServiceXPCKeys) -> Bool {
self.bool(key: key.rawValue)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ public actor ImagesService {
return try await imageStore.list().map { $0.description.fromCZ }
}

public func pull(reference: String, platform: Platform?, insecure: Bool, progressUpdate: ProgressUpdateHandler?) async throws -> ImageDescription {
self.log.info("ImagesService: \(#function) - ref: \(reference), platform: \(String(describing: platform)), insecure: \(insecure)")
public func pull(reference: String, platform: Platform?, insecure: Bool, progressUpdate: ProgressUpdateHandler?, maxConcurrentDownloads: Int = 3) async throws -> ImageDescription {
self.log.info("ImagesService: \(#function) - ref: \(reference), platform: \(String(describing: platform)), insecure: \(insecure), maxConcurrentDownloads: \(maxConcurrentDownloads)")
let img = try await Self.withAuthentication(ref: reference) { auth in
try await self.imageStore.pull(
reference: reference, platform: platform, insecure: insecure, auth: auth, progress: ContainerizationProgressAdapter.handler(from: progressUpdate))
reference: reference, platform: platform, insecure: insecure, auth: auth, progress: ContainerizationProgressAdapter.handler(from: progressUpdate), maxConcurrentDownloads: maxConcurrentDownloads)
}
guard let img else {
throw ContainerizationError(.internalError, message: "Failed to pull image \(reference)")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ public struct ImagesServiceHarness: Sendable {
platform = try JSONDecoder().decode(ContainerizationOCI.Platform.self, from: platformData)
}
let insecure = message.bool(key: .insecureFlag)
let maxConcurrentDownloads = message.int64(key: .maxConcurrentDownloads)

let progressUpdateService = ProgressUpdateService(message: message)
let imageDescription = try await service.pull(reference: ref, platform: platform, insecure: insecure, progressUpdate: progressUpdateService?.handler)
let imageDescription = try await service.pull(reference: ref, platform: platform, insecure: insecure, progressUpdate: progressUpdateService?.handler, maxConcurrentDownloads: Int(maxConcurrentDownloads))

let imageData = try JSONEncoder().encode(imageDescription)
let reply = message.reply()
Expand Down
49 changes: 47 additions & 2 deletions Sources/TerminalProgress/ProgressTaskCoordinator.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import Foundation

/// A type that represents a task whose progress is being monitored.
public struct ProgressTask: Sendable, Equatable {
public struct ProgressTask: Sendable, Equatable, Hashable {
private var id = UUID()
private var coordinator: ProgressTaskCoordinator
internal var coordinator: ProgressTaskCoordinator

init(manager: ProgressTaskCoordinator) {
self.coordinator = manager
Expand All @@ -29,6 +29,10 @@ public struct ProgressTask: Sendable, Equatable {
lhs.id == rhs.id
}

public func hash(into hasher: inout Hasher) {
hasher.combine(id)
}

/// Returns `true` if this task is the currently active task, `false` otherwise.
public func isCurrent() async -> Bool {
guard let currentTask = await coordinator.currentTask else {
Expand All @@ -41,6 +45,7 @@ public struct ProgressTask: Sendable, Equatable {
/// A type that coordinates progress tasks to ignore updates from completed tasks.
public actor ProgressTaskCoordinator {
var currentTask: ProgressTask?
var activeTasks: Set<ProgressTask> = []

/// Creates an instance of `ProgressTaskCoordinator`.
public init() {}
Expand All @@ -52,9 +57,36 @@ public actor ProgressTaskCoordinator {
return newTask
}

/// Starts multiple concurrent tasks and returns them.
/// - Parameter count: The number of concurrent tasks to start.
/// - Returns: An array of ProgressTask instances.
public func startConcurrentTasks(count: Int) -> [ProgressTask] {
var tasks: [ProgressTask] = []
for _ in 0..<count {
let task = ProgressTask(manager: self)
tasks.append(task)
activeTasks.insert(task)
}
return tasks
}

/// Marks a specific task as completed and removes it from active tasks.
/// - Parameter task: The task to mark as completed.
public func completeTask(_ task: ProgressTask) {
activeTasks.remove(task)
}

/// Checks if a task is currently active.
/// - Parameter task: The task to check.
/// - Returns: `true` if the task is active, `false` otherwise.
public func isTaskActive(_ task: ProgressTask) -> Bool {
activeTasks.contains(task)
}

/// Performs cleanup when the monitored tasks complete.
public func finish() {
currentTask = nil
activeTasks.removeAll()
}

/// Returns a handler that updates the progress of a given task.
Expand All @@ -69,4 +101,17 @@ public actor ProgressTaskCoordinator {
}
}
}

/// Returns a handler that updates the progress for concurrent tasks.
/// - Parameters:
/// - task: The task whose progress is being updated.
/// - progressUpdate: The handler to invoke when progress updates are received.
public static func concurrentHandler(for task: ProgressTask, from progressUpdate: @escaping ProgressUpdateHandler) -> ProgressUpdateHandler {
{ events in
// Only process updates if the task is still active
if await task.coordinator.isTaskActive(task) {
await progressUpdate(events)
}
}
}
}
109 changes: 109 additions & 0 deletions test_concurrency.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#!/usr/bin/env swift

import Foundation

func testConcurrentDownloads() async throws {
print("Testing concurrent download behavior...\n")

// Track concurrent task count
actor ConcurrencyTracker {
var currentCount = 0
var maxObservedCount = 0
var completedTasks = 0

func taskStarted() {
currentCount += 1
maxObservedCount = max(maxObservedCount, currentCount)
}

func taskCompleted() {
currentCount -= 1
completedTasks += 1
}

func getStats() -> (max: Int, completed: Int) {
return (maxObservedCount, completedTasks)
}

func reset() {
currentCount = 0
maxObservedCount = 0
completedTasks = 0
}
}

let tracker = ConcurrencyTracker()

// Test with different concurrency limits
for maxConcurrent in [1, 3, 6] {
await tracker.reset()

// Simulate downloading 20 layers
let layerCount = 20
let layers = Array(0..<layerCount)

print("Testing maxConcurrent=\(maxConcurrent) with \(layerCount) layers...")

let startTime = Date()

try await withThrowingTaskGroup(of: Void.self) { group in
var iterator = layers.makeIterator()

// Start initial batch based on maxConcurrent
for _ in 0..<maxConcurrent {
if iterator.next() != nil {
group.addTask {
await tracker.taskStarted()
try await Task.sleep(nanoseconds: 10_000_000)
await tracker.taskCompleted()
}
}
}
for try await _ in group {
if iterator.next() != nil {
group.addTask {
await tracker.taskStarted()
try await Task.sleep(nanoseconds: 10_000_000)
await tracker.taskCompleted()
}
}
}
}

let duration = Date().timeIntervalSince(startTime)
let stats = await tracker.getStats()

print(" ✓ Completed: \(stats.completed)/\(layerCount)")
print(" ✓ Max concurrent: \(stats.max)")
print(" ✓ Duration: \(String(format: "%.3f", duration))s")

guard stats.max <= maxConcurrent + 1 else {
throw TestError.concurrencyLimitExceeded
}

guard stats.completed == layerCount else {
throw TestError.incompleteTasks
}

print(" ✅ PASSED\n")
}

print("All tests passed!")
}

enum TestError: Error {
case concurrencyLimitExceeded
case incompleteTasks
}

Task {
do {
try await testConcurrentDownloads()
exit(0)
} catch {
print("Test failed: \(error)")
exit(1)
}
}

RunLoop.main.run()
Loading