Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/basic_memory/repository/fastembed_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import asyncio
import math
from typing import TYPE_CHECKING

from loguru import logger
Expand Down Expand Up @@ -119,10 +120,17 @@ def _embed_batch() -> list[list[float]]:
if effective_parallel is not None:
embed_kwargs["parallel"] = effective_parallel
vectors = list(model.embed(texts, **embed_kwargs))
# sqlite_search_repository.py uses a distance-to-similarity formula that assumes
# unit-normalized vectors (see the comment on lines 65-67 of that file).
# Some models (e.g. multilingual ones) return vectors with norm > 1, so we
# L2-normalize here to satisfy that contract regardless of the chosen model.
normalized: list[list[float]] = []
for vector in vectors:
values = vector.tolist() if hasattr(vector, "tolist") else vector
normalized.append([float(value) for value in values])
values = vector.tolist() if hasattr(vector, "tolist") else list(vector)
norm = math.sqrt(sum(x * x for x in values))
if norm > 0:
values = [x / norm for x in values]
normalized.append([float(v) for v in values])
return normalized

vectors = await asyncio.to_thread(_embed_batch)
Expand Down
63 changes: 63 additions & 0 deletions tests/repository/test_fastembed_provider.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for FastEmbedEmbeddingProvider."""

import builtins
import math
import sys

import pytest
Expand Down Expand Up @@ -148,3 +149,65 @@ async def test_fastembed_provider_parallel_two_passes_multiprocessing(monkeypatc
await provider.embed_documents(["parallel enabled"])

assert _StubTextEmbedding.last_embed_kwargs == {"batch_size": 64, "parallel": 2}


class _UnormalizedVector:
"""Stub vector with norm != 1 (simulates multilingual models like paraphrase-multilingual-*)."""

def __init__(self, values):
self._values = values

def tolist(self):
return self._values


class _UnnormalizedTextEmbedding:
def __init__(self, model_name: str, **_kwargs):
self.model_name = model_name

def embed(self, texts: list[str], **_kwargs):
# Return a vector with norm ~= 2.9 (typical for multilingual MiniLM models)
for _ in texts:
yield _UnormalizedVector([1.5, 2.0, 1.0, 0.5])


@pytest.mark.asyncio
async def test_fastembed_provider_l2_normalizes_output_vectors(monkeypatch):
    """The provider must hand back unit-length vectors whatever the model emits.

    sqlite_search_repository scores results with a formula that assumes norm == 1.
    Models such as paraphrase-multilingual-MiniLM-L12-v2 are reported to return
    vectors with norm ~2.9, which breaks cosine similarity scoring, so the
    provider is expected to L2-normalize before returning.
    """
    # Swap in a fake ``fastembed`` module so the provider loads the stub model.
    fake_fastembed = type(sys)("fastembed")
    setattr(fake_fastembed, "TextEmbedding", _UnnormalizedTextEmbedding)
    monkeypatch.setitem(sys.modules, "fastembed", fake_fastembed)

    provider = FastEmbedEmbeddingProvider(model_name="stub-multilingual", dimensions=4)
    embeddings = await provider.embed_documents(["some text"])

    assert len(embeddings) == 1
    squared_total = sum(component * component for component in embeddings[0])
    norm = math.sqrt(squared_total)
    assert abs(norm - 1.0) < 1e-6, f"Expected unit norm, got {norm}"


@pytest.mark.asyncio
async def test_fastembed_provider_zero_vector_does_not_raise(monkeypatch):
    """An all-zero model output passes through unchanged instead of dividing by zero."""

    class _ZeroEmbedding:
        """Stub model that emits a zero vector for every input text."""

        def __init__(self, model_name: str, **_kwargs):
            """Accept and ignore all constructor arguments."""

        def embed(self, texts: list[str], **_kwargs):
            yield from (_UnormalizedVector([0.0, 0.0, 0.0, 0.0]) for _ in texts)

    # Install the stub under the name the provider imports.
    stub_module = type(sys)("fastembed")
    setattr(stub_module, "TextEmbedding", _ZeroEmbedding)
    monkeypatch.setitem(sys.modules, "fastembed", stub_module)

    provider = FastEmbedEmbeddingProvider(model_name="stub-zero", dimensions=4)
    embeddings = await provider.embed_documents(["zero vector"])

    expected = [[0.0, 0.0, 0.0, 0.0]]
    assert embeddings == expected
Loading