Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions embed_ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
const defaultBaseURLOllama = "http://localhost:11434/api"

type ollamaResponse struct {
Embedding []float32 `json:"embedding"`
Embeddings [][]float32 `json:"embeddings"`
}

// NewEmbeddingFuncOllama returns a function that creates embeddings for a text
Expand All @@ -39,16 +39,17 @@ func NewEmbeddingFuncOllama(model string, baseURLOllama string) EmbeddingFunc {
return func(ctx context.Context, text string) ([]float32, error) {
// Prepare the request body.
reqBody, err := json.Marshal(map[string]string{
"model": model,
"prompt": text,
"model": model,
"input": text,
})

if err != nil {
return nil, fmt.Errorf("couldn't marshal request body: %w", err)
}

// Create the request. Creating it with context is important for a timeout
// to be possible, because the client is configured without a timeout.
req, err := http.NewRequestWithContext(ctx, "POST", baseURLOllama+"/embeddings", bytes.NewBuffer(reqBody))
req, err := http.NewRequestWithContext(ctx, "POST", baseURLOllama+"/embed", bytes.NewBuffer(reqBody))
if err != nil {
return nil, fmt.Errorf("couldn't create request: %w", err)
}
Expand Down Expand Up @@ -78,11 +79,11 @@ func NewEmbeddingFuncOllama(model string, baseURLOllama string) EmbeddingFunc {
}

// Check if the response contains embeddings.
if len(embeddingResponse.Embedding) == 0 {
if len(embeddingResponse.Embeddings) == 0 {
return nil, errors.New("no embeddings found in the response")
}

v := embeddingResponse.Embedding
v := embeddingResponse.Embeddings[0]
checkNormalized.Do(func() {
if isNormalized(v) {
checkedNormalized = true
Expand Down
14 changes: 7 additions & 7 deletions embed_ollama_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@ func TestNewEmbeddingFuncOllama(t *testing.T) {
prompt := "hello world"

wantBody, err := json.Marshal(map[string]string{
"model": model,
"prompt": prompt,
"model": model,
"input": prompt,
})
if err != nil {
t.Fatal("unexpected error:", err)
}
wantRes := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
wantRes := [][]float32{{-0.40824828, 0.40824828, 0.81649655}} // normalized version of `{-0.1, 0.1, 0.2}`

// Mock server
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check URL
if !strings.HasSuffix(r.URL.Path, baseURLSuffix+"/embeddings") {
t.Fatal("expected URL", baseURLSuffix+"/embeddings", "got", r.URL.Path)
if !strings.HasSuffix(r.URL.Path, baseURLSuffix+"/embed") {
t.Fatal("expected URL", baseURLSuffix+"/embed", "got", r.URL.Path)
}
// Check method
if r.Method != "POST" {
Expand All @@ -52,7 +52,7 @@ func TestNewEmbeddingFuncOllama(t *testing.T) {

// Write response
resp := ollamaResponse{
Embedding: wantRes,
Embeddings: wantRes,
}
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(resp)
Expand All @@ -70,7 +70,7 @@ func TestNewEmbeddingFuncOllama(t *testing.T) {
if err != nil {
t.Fatal("expected nil, got", err)
}
if slices.Compare(wantRes, res) != 0 {
if slices.Compare(wantRes[0], res) != 0 {
t.Fatal("expected res", wantRes, "got", res)
}
}