diff --git a/src/basic_memory/repository/fastembed_provider.py b/src/basic_memory/repository/fastembed_provider.py index 5ade579e9..3d637aced 100644 --- a/src/basic_memory/repository/fastembed_provider.py +++ b/src/basic_memory/repository/fastembed_provider.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import math from typing import TYPE_CHECKING from loguru import logger @@ -119,10 +120,17 @@ def _embed_batch() -> list[list[float]]: if effective_parallel is not None: embed_kwargs["parallel"] = effective_parallel vectors = list(model.embed(texts, **embed_kwargs)) + # sqlite_search_repository.py uses a distance-to-similarity formula that assumes + # unit-normalized vectors (see the comment on line 65-67 of that file). + # Some models (e.g. multilingual ones) return vectors with norm > 1, so we + # L2-normalize here to satisfy that contract regardless of the chosen model. normalized: list[list[float]] = [] for vector in vectors: - values = vector.tolist() if hasattr(vector, "tolist") else vector - normalized.append([float(value) for value in values]) + values = vector.tolist() if hasattr(vector, "tolist") else list(vector) + norm = math.sqrt(sum(x * x for x in values)) + if norm > 0: + values = [x / norm for x in values] + normalized.append([float(v) for v in values]) return normalized vectors = await asyncio.to_thread(_embed_batch) diff --git a/tests/repository/test_fastembed_provider.py b/tests/repository/test_fastembed_provider.py index bc0aed7ca..a8e073d8c 100644 --- a/tests/repository/test_fastembed_provider.py +++ b/tests/repository/test_fastembed_provider.py @@ -1,6 +1,7 @@ """Tests for FastEmbedEmbeddingProvider.""" import builtins +import math import sys import pytest @@ -148,3 +149,65 @@ async def test_fastembed_provider_parallel_two_passes_multiprocessing(monkeypatc await provider.embed_documents(["parallel enabled"]) assert _StubTextEmbedding.last_embed_kwargs == {"batch_size": 64, "parallel": 2} + + +class _UnormalizedVector: + """Stub vector with norm != 1 (simulates multilingual models like paraphrase-multilingual-*).""" + + def __init__(self, values): + self._values = values + + def tolist(self): + return self._values + + +class _UnnormalizedTextEmbedding: + def __init__(self, model_name: str, **_kwargs): + self.model_name = model_name + + def embed(self, texts: list[str], **_kwargs): + # Return a vector with norm ~= 2.9 (typical for multilingual MiniLM models) + for _ in texts: + yield _UnormalizedVector([1.5, 2.0, 1.0, 0.5]) + + +@pytest.mark.asyncio +async def test_fastembed_provider_l2_normalizes_output_vectors(monkeypatch): + """Returned vectors must be unit-normalized regardless of the raw model output. + + sqlite_search_repository uses a formula that assumes norm == 1. Models such as + paraphrase-multilingual-MiniLM-L12-v2 return vectors with norm ~2.9, which breaks + cosine similarity scoring. The provider must apply L2 normalization before returning. + """ + module = type(sys)("fastembed") + setattr(module, "TextEmbedding", _UnnormalizedTextEmbedding) + monkeypatch.setitem(sys.modules, "fastembed", module) + + provider = FastEmbedEmbeddingProvider(model_name="stub-multilingual", dimensions=4) + result = await provider.embed_documents(["some text"]) + + assert len(result) == 1 + norm = math.sqrt(sum(x * x for x in result[0])) + assert abs(norm - 1.0) < 1e-6, f"Expected unit norm, got {norm}" + + +@pytest.mark.asyncio +async def test_fastembed_provider_zero_vector_does_not_raise(monkeypatch): + """A zero vector from the model must be returned as-is without a division error.""" + + class _ZeroEmbedding: + def __init__(self, model_name: str, **_kwargs): + pass + + def embed(self, texts: list[str], **_kwargs): + for _ in texts: + yield _UnormalizedVector([0.0, 0.0, 0.0, 0.0]) + + module = type(sys)("fastembed") + setattr(module, "TextEmbedding", _ZeroEmbedding) + monkeypatch.setitem(sys.modules, "fastembed", module) + + provider = FastEmbedEmbeddingProvider(model_name="stub-zero", dimensions=4) + result = await provider.embed_documents(["zero vector"]) + + assert result == [[0.0, 0.0, 0.0, 0.0]]