Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions Sources/ContainerClient/Core/ClientImage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,13 @@ extension ClientImage {
})
}

public static func pull(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil) async throws -> ClientImage {
public static func pull(
reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil, maxConcurrentDownloads: Int = 3
) async throws -> ClientImage {
guard maxConcurrentDownloads > 0 else {
throw ContainerizationError(.invalidArgument, message: "--max-concurrent-downloads must be greater than 0, got \(maxConcurrentDownloads)")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't refer to command line arguments inside the client.

Should read "maximum number of concurrent downloads must be..."

}

let client = newXPCClient()
let request = newRequest(.imagePull)

Expand All @@ -234,6 +240,7 @@ extension ClientImage {

let insecure = try scheme.schemeFor(host: host) == .http
request.set(key: .insecureFlag, value: insecure)
request.set(key: .maxConcurrentDownloads, value: Int64(maxConcurrentDownloads))

var progressUpdateClient: ProgressUpdateClient?
if let progressUpdate {
Expand Down Expand Up @@ -313,8 +320,9 @@ extension ClientImage {
return (totalCount: total, activeCount: active, totalSize: size, reclaimableSize: reclaimable)
}

public static func fetch(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil) async throws -> ClientImage
{
public static func fetch(
reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil, maxConcurrentDownloads: Int = 3
) async throws -> ClientImage {
do {
let match = try await self.get(reference: reference)
if let platform {
Expand All @@ -327,7 +335,7 @@ extension ClientImage {
guard err.isCode(.notFound) else {
throw err
}
return try await Self.pull(reference: reference, platform: platform, scheme: scheme, progressUpdate: progressUpdate)
return try await Self.pull(reference: reference, platform: platform, scheme: scheme, progressUpdate: progressUpdate, maxConcurrentDownloads: maxConcurrentDownloads)
}
}
}
Expand Down
7 changes: 7 additions & 0 deletions Sources/ContainerClient/Flags.swift
Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,11 @@ public struct Flags {
@Option(name: .long, help: ArgumentHelp("Progress type (format: none|ansi)", valueName: "type"))
public var progress: ProgressType = .ansi
}

public struct ImageFetch: ParsableArguments {
public init() {}

@Option(name: .long, help: "Maximum number of concurrent layer downloads (default: 3)")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@adityaramani Should we use the word "layer" or "blob" here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use just "concurrent downloads"

public var maxConcurrentDownloads: Int = 3
}
}
7 changes: 5 additions & 2 deletions Sources/ContainerClient/Utility.swift
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ public struct Utility {
management: Flags.Management,
resource: Flags.Resource,
registry: Flags.Registry,
imageFetch: Flags.ImageFetch,
progressUpdate: @escaping ProgressUpdateHandler
) async throws -> (ContainerConfiguration, Kernel) {
var requestedPlatform = Parser.platform(os: management.os, arch: management.arch)
Expand All @@ -112,7 +113,8 @@ public struct Utility {
reference: image,
platform: requestedPlatform,
scheme: scheme,
progressUpdate: ProgressTaskCoordinator.handler(for: fetchTask, from: progressUpdate)
progressUpdate: ProgressTaskCoordinator.handler(for: fetchTask, from: progressUpdate),
maxConcurrentDownloads: imageFetch.maxConcurrentDownloads
)

// Unpack a fetched image before use
Expand Down Expand Up @@ -140,7 +142,8 @@ public struct Utility {
let fetchInitTask = await taskManager.startTask()
let initImage = try await ClientImage.fetch(
reference: ClientImage.initImageRef, platform: .current, scheme: scheme,
progressUpdate: ProgressTaskCoordinator.handler(for: fetchInitTask, from: progressUpdate))
progressUpdate: ProgressTaskCoordinator.handler(for: fetchInitTask, from: progressUpdate),
maxConcurrentDownloads: imageFetch.maxConcurrentDownloads)

await progressUpdate([
.setDescription("Unpacking init image"),
Expand Down
4 changes: 4 additions & 0 deletions Sources/ContainerCommands/Container/ContainerCreate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ extension Application {
@OptionGroup(title: "Registry options")
var registryFlags: Flags.Registry

@OptionGroup(title: "Image fetch options")
var imageFetchFlags: Flags.ImageFetch

@OptionGroup
var global: Flags.Global

Expand Down Expand Up @@ -73,6 +76,7 @@ extension Application {
management: managementFlags,
resource: resourceFlags,
registry: registryFlags,
imageFetch: imageFetchFlags,
progressUpdate: progress.handler
)

Expand Down
4 changes: 4 additions & 0 deletions Sources/ContainerCommands/Container/ContainerRun.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ extension Application {
@OptionGroup(title: "Progress options")
var progressFlags: Flags.Progress

@OptionGroup(title: "Image fetch options")
var imageFetchFlags: Flags.ImageFetch

@OptionGroup
var global: Flags.Global

Expand Down Expand Up @@ -97,6 +100,7 @@ extension Application {
management: managementFlags,
resource: resourceFlags,
registry: registryFlags,
imageFetch: imageFetchFlags,
progressUpdate: progress.handler
)

Expand Down
6 changes: 5 additions & 1 deletion Sources/ContainerCommands/Image/ImagePull.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ extension Application {
@OptionGroup
var progressFlags: Flags.Progress

@OptionGroup
var imageFetchFlags: Flags.ImageFetch

@Option(
name: .shortAndLong,
help: "Limit the pull to the specified architecture"
Expand Down Expand Up @@ -100,7 +103,8 @@ extension Application {
let taskManager = ProgressTaskCoordinator()
let fetchTask = await taskManager.startTask()
let image = try await ClientImage.pull(
reference: processedReference, platform: p, scheme: scheme, progressUpdate: ProgressTaskCoordinator.handler(for: fetchTask, from: progress.handler)
reference: processedReference, platform: p, scheme: scheme, progressUpdate: ProgressTaskCoordinator.handler(for: fetchTask, from: progress.handler),
maxConcurrentDownloads: self.imageFetchFlags.maxConcurrentDownloads
)

progress.set(description: "Unpacking image")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public enum ImagesServiceXPCKeys: String {
case ociPlatform
case insecureFlag
case garbageCollect
case maxConcurrentDownloads

/// ContentStore
case digest
Expand All @@ -60,15 +61,15 @@ extension XPCMessage {
self.set(key: key.rawValue, value: value)
}

public func set(key: ImagesServiceXPCKeys, value: Data) {
public func set(key: ImagesServiceXPCKeys, value: Int64) {
self.set(key: key.rawValue, value: value)
}

public func set(key: ImagesServiceXPCKeys, value: Bool) {
public func set(key: ImagesServiceXPCKeys, value: Data) {
self.set(key: key.rawValue, value: value)
}

public func set(key: ImagesServiceXPCKeys, value: Int64) {
public func set(key: ImagesServiceXPCKeys, value: Bool) {
self.set(key: key.rawValue, value: value)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,15 @@ public actor ImagesService {
return try await imageStore.list().map { $0.description.fromCZ }
}

public func pull(reference: String, platform: Platform?, insecure: Bool, progressUpdate: ProgressUpdateHandler?) async throws -> ImageDescription {
self.log.info("ImagesService: \(#function) - ref: \(reference), platform: \(String(describing: platform)), insecure: \(insecure)")
public func pull(reference: String, platform: Platform?, insecure: Bool, progressUpdate: ProgressUpdateHandler?, maxConcurrentDownloads: Int = 3) async throws
-> ImageDescription
{
self.log.info(
"ImagesService: \(#function) - ref: \(reference), platform: \(String(describing: platform)), insecure: \(insecure), maxConcurrentDownloads: \(maxConcurrentDownloads)")
let img = try await Self.withAuthentication(ref: reference) { auth in
try await self.imageStore.pull(
reference: reference, platform: platform, insecure: insecure, auth: auth, progress: ContainerizationProgressAdapter.handler(from: progressUpdate))
reference: reference, platform: platform, insecure: insecure, auth: auth, progress: ContainerizationProgressAdapter.handler(from: progressUpdate),
maxConcurrentDownloads: maxConcurrentDownloads)
}
guard let img else {
throw ContainerizationError(.internalError, message: "failed to pull image \(reference)")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,11 @@ public struct ImagesServiceHarness: Sendable {
platform = try JSONDecoder().decode(ContainerizationOCI.Platform.self, from: platformData)
}
let insecure = message.bool(key: .insecureFlag)
let maxConcurrentDownloads = message.int64(key: .maxConcurrentDownloads)

let progressUpdateService = ProgressUpdateService(message: message)
let imageDescription = try await service.pull(reference: ref, platform: platform, insecure: insecure, progressUpdate: progressUpdateService?.handler)
let imageDescription = try await service.pull(
reference: ref, platform: platform, insecure: insecure, progressUpdate: progressUpdateService?.handler, maxConcurrentDownloads: Int(maxConcurrentDownloads))

let imageData = try JSONEncoder().encode(imageDescription)
let reply = message.reply()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ public actor SandboxService {
try bundle.createLogFile()

var config = try bundle.configuration

let vmm = VZVirtualMachineManager(
kernel: try bundle.kernel,
initialFilesystem: bundle.initialFilesystem.asMount,
Expand Down
35 changes: 35 additions & 0 deletions Tests/CLITests/Subcommands/Images/TestCLIImages.swift
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,41 @@ extension TestCLIImagesCommand {
}
}

@Test func testMaxConcurrentDownloadsValidation() throws {
// Test that invalid maxConcurrentDownloads value is rejected
let (_, _, error, status) = try run(arguments: [
"image",
"pull",
"--max-concurrent-downloads", "0",
"alpine:latest",
])

#expect(status != 0, "Expected command to fail with maxConcurrentDownloads=0")
#expect(
error.contains("--max-concurrent-downloads must be greater than 0"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix to match the change

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

made the update

"Expected validation error message in output")
}

@Test func testMaxConcurrentDownloadsFlag() throws {
// Test that the flag is accepted with valid values
do {
try doPull(imageName: alpine, args: ["--max-concurrent-downloads", "1"])
let imagePresent = try isImagePresent(targetImage: alpine)
#expect(imagePresent, "Expected image to be pulled with maxConcurrentDownloads=1")

// Clean up
try? doRemoveImages(images: [alpine])

// Test with higher concurrency
try doPull(imageName: alpine, args: ["--max-concurrent-downloads", "6"])
let imagePresent2 = try isImagePresent(targetImage: alpine)
#expect(imagePresent2, "Expected image to be pulled with maxConcurrentDownloads=6")
} catch {
Issue.record("failed to pull image with maxConcurrentDownloads flag: \(error)")
return
}
}

@Test func testImageSaveAndLoadStdinStdout() throws {
do {
// 1. pull image
Expand Down