Skip to content

Commit f35715f

Browse files
CopilotOhYee
andauthored
Fix MySQL embedder dimension regression in memory_collection._build_mem0_config
Agent-Logs-Url: https://github.com/Serverless-Devs/agentrun-sdk-python/sessions/39e6e1c4-a68e-4386-ab8c-af3c068c7414 Co-authored-by: OhYee <13498329+OhYee@users.noreply.github.com>
1 parent e6d5dc9 commit f35715f

3 files changed

Lines changed: 96 additions & 24 deletions

File tree

agentrun/memory_collection/__memory_collection_async_template.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -476,14 +476,15 @@ async def _build_mem0_config_async(
476476
}
477477

478478
# 从 vector_store_config 中获取向量维度
479-
if (
480-
memory_collection.vector_store_config
481-
and memory_collection.vector_store_config.config
482-
and memory_collection.vector_store_config.config.vector_dimension
483-
):
484-
embedder_config_dict["embedding_dims"] = (
485-
memory_collection.vector_store_config.config.vector_dimension
486-
)
479+
vector_dimension: Optional[int] = None
480+
if memory_collection.vector_store_config:
481+
vsc = memory_collection.vector_store_config
482+
if vsc.config and vsc.config.vector_dimension:
483+
vector_dimension = vsc.config.vector_dimension
484+
elif vsc.mysql_config and vsc.mysql_config.vector_dimension:
485+
vector_dimension = vsc.mysql_config.vector_dimension
486+
if vector_dimension:
487+
embedder_config_dict["embedding_dims"] = vector_dimension
487488

488489
mem0_config["embedder"] = {
489490
"provider": "openai", # mem0 使用 openai 兼容接口

agentrun/memory_collection/memory_collection.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -714,14 +714,15 @@ async def _build_mem0_config_async(
714714
}
715715

716716
# 从 vector_store_config 中获取向量维度
717-
if (
718-
memory_collection.vector_store_config
719-
and memory_collection.vector_store_config.config
720-
and memory_collection.vector_store_config.config.vector_dimension
721-
):
722-
embedder_config_dict["embedding_dims"] = (
723-
memory_collection.vector_store_config.config.vector_dimension
724-
)
717+
vector_dimension: Optional[int] = None
718+
if memory_collection.vector_store_config:
719+
vsc = memory_collection.vector_store_config
720+
if vsc.config and vsc.config.vector_dimension:
721+
vector_dimension = vsc.config.vector_dimension
722+
elif vsc.mysql_config and vsc.mysql_config.vector_dimension:
723+
vector_dimension = vsc.mysql_config.vector_dimension
724+
if vector_dimension:
725+
embedder_config_dict["embedding_dims"] = vector_dimension
725726

726727
mem0_config["embedder"] = {
727728
"provider": "openai", # mem0 使用 openai 兼容接口
@@ -880,14 +881,15 @@ def _build_mem0_config(
880881
}
881882

882883
# 从 vector_store_config 中获取向量维度
883-
if (
884-
memory_collection.vector_store_config
885-
and memory_collection.vector_store_config.config
886-
and memory_collection.vector_store_config.config.vector_dimension
887-
):
888-
embedder_config_dict["embedding_dims"] = (
889-
memory_collection.vector_store_config.config.vector_dimension
890-
)
884+
vector_dimension: Optional[int] = None
885+
if memory_collection.vector_store_config:
886+
vsc = memory_collection.vector_store_config
887+
if vsc.config and vsc.config.vector_dimension:
888+
vector_dimension = vsc.config.vector_dimension
889+
elif vsc.mysql_config and vsc.mysql_config.vector_dimension:
890+
vector_dimension = vsc.mysql_config.vector_dimension
891+
if vector_dimension:
892+
embedder_config_dict["embedding_dims"] = vector_dimension
891893

892894
mem0_config["embedder"] = {
893895
"provider": "openai", # mem0 使用 openai 兼容接口

tests/unittests/memory_collection/test_memory_collection.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,75 @@ def test_build_mem0_config_with_mysql_sync(self, mock_get_credential):
630630
assert vs_config["port"] == 3307
631631
assert vs_config["embedding_model_dims"] == 1024
632632

633+
@patch("agentrun.memory_collection.memory_collection.MemoryCollection._resolve_model_service_config")
634+
@patch("agentrun.credential.Credential.get_by_name")
635+
def test_build_mem0_config_mysql_embedder_dims_sync(
636+
self, mock_get_credential, mock_resolve
637+
):
638+
"""测试 MySQL provider 时 embedder 的 embedding_dims 应从 mysql_config 读取"""
639+
mock_credential = MagicMock()
640+
mock_credential.credential_secret = "test-password"
641+
mock_get_credential.return_value = mock_credential
642+
mock_resolve.return_value = ("https://api.example.com", "sk-fake")
643+
644+
memory_collection = MemoryCollection(
645+
memory_collection_name="t",
646+
vector_store_config=VectorStoreConfig(
647+
provider="alibabacloud_mysql",
648+
mysql_config=VectorStoreConfigMysqlConfig(
649+
host="h",
650+
port=3306,
651+
db_name="d",
652+
user="u",
653+
collection_name="c",
654+
credential_name="cred",
655+
vector_dimension=1024,
656+
),
657+
),
658+
embedder_config=EmbedderConfig(
659+
model_service_name="my-model-svc",
660+
config=EmbedderConfigConfig(model="text-embedding-v3"),
661+
),
662+
)
663+
config = MemoryCollection._build_mem0_config(memory_collection, None, None)
664+
assert config["embedder"]["config"]["embedding_dims"] == 1024
665+
666+
@patch("agentrun.memory_collection.memory_collection.MemoryCollection._resolve_model_service_config_async")
667+
@patch("agentrun.credential.Credential.get_by_name_async")
668+
@pytest.mark.asyncio
669+
async def test_build_mem0_config_mysql_embedder_dims_async(
670+
self, mock_get_credential, mock_resolve
671+
):
672+
"""测试 MySQL provider 时异步 embedder 的 embedding_dims 应从 mysql_config 读取"""
673+
mock_credential = MagicMock()
674+
mock_credential.credential_secret = "test-password"
675+
mock_get_credential.return_value = mock_credential
676+
mock_resolve.return_value = ("https://api.example.com", "sk-fake")
677+
678+
memory_collection = MemoryCollection(
679+
memory_collection_name="t",
680+
vector_store_config=VectorStoreConfig(
681+
provider="alibabacloud_mysql",
682+
mysql_config=VectorStoreConfigMysqlConfig(
683+
host="h",
684+
port=3306,
685+
db_name="d",
686+
user="u",
687+
collection_name="c",
688+
credential_name="cred",
689+
vector_dimension=1024,
690+
),
691+
),
692+
embedder_config=EmbedderConfig(
693+
model_service_name="my-model-svc",
694+
config=EmbedderConfigConfig(model="text-embedding-v3"),
695+
),
696+
)
697+
config = await MemoryCollection._build_mem0_config_async(
698+
memory_collection, None, None
699+
)
700+
assert config["embedder"]["config"]["embedding_dims"] == 1024
701+
633702
@patch("agentrun.credential.Credential.get_by_name_async")
634703
@pytest.mark.asyncio
635704
async def test_build_mem0_config_mysql_default_values(

0 commit comments

Comments
 (0)