
How can I initialize multiple instances of different model classes on different GPUs? #572

@ywh-my

Description


🚀 Feature

Assume the user defines two API routes, "/sentiment" for text sentiment analysis and "/generate" for text dialogue generation, both powered by locally deployed AI models. I would like to instantiate 2 sentiment analysis model instances on GPU 0 and 3 dialogue generation model instances on GPU 1. Example code:

sentiment_api = SentimentAnalysisAPI(api_path="/sentiment")
chat_api1 = TextGenerationAPI(api_path="/generate")

server = ls.LitServer([sentiment_api, chat_api1], devices=[[0, 0], [1, 1, 1]])
server.run(port=8881)
This means there would be two instances of the SentimentAnalysisAPI model accessible via the "/sentiment" route on GPU 0, and three instances of the TextGenerationAPI model accessible via the "/generate" route on GPU 1, with concurrent processing across these model instances handled by litserve.
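
For illustration, a client would then call both routes on the same port. This is just a hypothetical usage sketch; the payload keys ("text" and "prompt") simply match the decode_request fields used in the full example further below:

import requests

# Hypothetical client calls against the proposed server on port 8881.
resp_sentiment = requests.post("http://127.0.0.1:8881/sentiment", json={"text": "great product"})
resp_generate = requests.post("http://127.0.0.1:8881/generate", json={"prompt": "hello"})
print(resp_sentiment.json())
print(resp_generate.json())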

However, the latest feature of litserve, "Multiple API endpoints on one port", does not seem to support this functionality. My example code is as follows:

import torch
import torch.nn as nn
import litserve as ls

# Simple linear model for sentiment analysis
class SimpleSentimentModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1024, 2)  # binary classification

    def forward(self, x):
        return self.fc(x)

# Simple (mock) model for text generation
class SimpleGenerationModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1024, 1024)

    def forward(self, x):
        return self.fc(x)

# Sentiment analysis API, served at /sentiment
class SentimentAnalysisAPI(ls.LitAPI):
    def __init__(self, api_path="/sentiment"):
        super().__init__(api_path=api_path)

    def setup(self, device):
        self.device = device
        self.model = SimpleSentimentModel().to(device)
        self.model.eval()

    def decode_request(self, request: dict):
        return request["text"]

    def predict(self, text):
        input_tensor = torch.randn(1, 1024).to(self.device)
        with torch.no_grad():
            output = self.model(input_tensor)
        return {"SentimentAnalysisAPI": output.tolist()}

# Text generation API, accepts a configurable path
class TextGenerationAPI(ls.LitAPI):
    def __init__(self, api_path="/generate"):
        super().__init__(api_path=api_path)

    def setup(self, device):
        self.device = device
        self.model = SimpleGenerationModel().to(device)
        self.model.eval()

    def decode_request(self, request: dict):
        return request["prompt"]

    def predict(self, prompt):
        input_tensor = torch.randn(1, 1024).to(self.device)
        with torch.no_grad():
            output = self.model(input_tensor)
        return {self.api_path: output.tolist()}

if __name__ == "__main__":
    # Initialize the sentiment analysis model
    sentiment_api = SentimentAnalysisAPI(api_path="/sentiment")

    # Initialize two text generation models
    chat_api1 = TextGenerationAPI(api_path="/generate1")

    chat_api2 = TextGenerationAPI(api_path="/generate2")

    # Start the server
    server = ls.LitServer([sentiment_api, chat_api1, chat_api2], devices=[1])
    server.run(port=8881)

In this case, GPU 1 hosts one sentiment analysis model and two text generation models. However, because the two text generation models are exposed as separate API instances on separate routes ("/generate1" and "/generate2"), requests to a single route cannot be spread across multiple instances of the same model, so there is no true per-route concurrency.
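
For reference, the closest existing knob I am aware of is workers_per_device, but as far as I can tell it replicates an API the same number of times on every selected device for the whole server, so it cannot express a per-API, per-GPU instance count like the one proposed above. A rough sketch under that assumption:

# Rough workaround sketch: replicate a single API with workers_per_device.
# As far as I know this applies uniformly, so it cannot express
# "2 sentiment instances on GPU 0 and 3 generation instances on GPU 1".
sentiment_api = SentimentAnalysisAPI()
server = ls.LitServer(sentiment_api, accelerator="gpu", devices=[0], workers_per_device=2)
server.run(port=8881)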

Finally, I hope that when users access the "/sentiment" route, the server can have multiple model instances ready to respond, and that litserve manages concurrent processing across these instances.
Thank you for reviewing my question.
