import argparse import tcvectordb import sys import time import ast import pandas as pd from datetime import datetime import numpy as np import json import copy from multiprocessing import Process, Queue from tcvectordb import RPCVectorDBClient from tcvectordb.model.collection import Collection from tcvectordb.model.enum import FieldType from tcvectordb.model.index import VectorIndex RUN_SCOPES = {} UN_LIMIT = sys.maxsize def get_client(connect_url, connect_key): return RPCVectorDBClient( url=RUN_SCOPES.get(connect_url), username='root', key=RUN_SCOPES.get(connect_key), timeout=30 ) def get_source_client(): return get_client('source_connect_url', 'source_connect_key') def get_target_client(): return get_client('target_connect_url', 'target_connect_key') def str2bool(v) -> bool: if isinstance(v, bool): return v return v in ("True", "true") def init_args(): py_parser = argparse.ArgumentParser(add_help=False) py_parser.add_argument("--source_connect_url", type=str, required=True) py_parser.add_argument("--source_connect_key", type=str, required=True) py_parser.add_argument("--source_db", type=str, required=True) py_parser.add_argument("--source_collection", type=str, required=True) py_parser.add_argument("--filter", type=str, required=False) py_parser.add_argument("--offset", type=int, default=0) py_parser.add_argument("--target_connect_url", type=str, required=True) py_parser.add_argument("--target_connect_key", type=str, required=True) py_parser.add_argument("--target_db", type=str, required=True) py_parser.add_argument("--target_collection", type=str, required=True) py_parser.add_argument("--parallel", type=int, default=30) py_parser.add_argument("--read_page_size", type=int, default=1000) py_parser.add_argument("--write_batch_size", type=int, default=-1) py_parser.add_argument("--field_mappings", type=str, default="") py_parser.add_argument("--write_build_index", type=str2bool, default=True) py_parser.add_argument("--rebuild_index_after_write", type=str2bool, default=False) args, _ = py_parser.parse_known_args() global RUN_SCOPES RUN_SCOPES = vars(args) print_run_scopes() def print_simple_line(): print('-------------------------------------------') def print_run_scopes(): run_scope = copy.copy(RUN_SCOPES) run_scope['source_connect_key'] = "******" run_scope['target_connect_key'] = "******" print_simple_line() print(f'[{datetime.now()}] [init args]\n {json.dumps(run_scope, ensure_ascii=False, indent=2)}') # 迁移前预处理 # 预检查 # 将字段映射转换为迁移过程中便于使用的结构 # 测试1条数据从源端到目的端的迁移 def pre_migrate(): param_check() target_dict = pre_check() try_migrate_first_row() compute_write_batch_size(target_dict) def compute_write_batch_size(target_dict: dict): batch_size = RUN_SCOPES['write_batch_size'] if RUN_SCOPES['write_batch_size'] < 0: dimensions = 0 for _, index in target_dict.items(): if isinstance(index, VectorIndex): dimensions += index.dimension if dimensions < 64: dimensions = 64 batch_size = int((4096 / dimensions) * 25) if batch_size < 10: batch_size = 10 elif batch_size > 1000: batch_size = 1000 RUN_SCOPES['write_batch_size'] = batch_size print(f"[{datetime.now()}] [upsert batch size]: {batch_size}, all dimensions: {dimensions}.") else: print(f"[{datetime.now()}] [upsert batch size]: {batch_size}.") # 测试1条数据从源端到目的端 def try_migrate_first_row(): print(f"[{datetime.now()}] [pre_check] try migrate first row.") source_db = RUN_SCOPES.get('source_db') source_collection = RUN_SCOPES.get('source_collection') source_client = get_source_client() target_db = RUN_SCOPES.get('target_db') target_collection = RUN_SCOPES.get('target_collection') write_build_index = RUN_SCOPES.get('write_build_index') target_client = get_target_client() try: first_row = get_first_row(source_client, source_db, source_collection) write_into_targets(target_client, target_db, target_collection, first_row, write_build_index, 3) print(f"[{datetime.now()}] [pre_check] The first data has been successfully written.") finally: source_client.close() target_client.close() def write_into_targets(client: RPCVectorDBClient, db_name, collection, rows, write_build_index: bool, timeout: float = None): running_mappings = RUN_SCOPES.get('running_mappings') upsert_rows = [] if not running_mappings: upsert_rows = rows else: for row in rows: upsert_row = {} upsert_rows.append(upsert_row) for key, value in row.items(): if key == 'score': continue target_field = running_mappings.get(key, None) target_key = target_field.get('target_key') if target_field else key upsert_row[target_key] = value client.upsert( database_name=db_name, collection_name=collection, documents=upsert_rows, build_index=write_build_index, timeout=timeout ) def get_first_row(client: RPCVectorDBClient, db_name, collection): query_filter: str = RUN_SCOPES.get("filter", None) offset = RUN_SCOPES.get('offset') result = client.query( database_name=db_name, collection_name=collection, limit=1, offset=offset, filter=query_filter, retrieve_vector=True ) # 源端没数据或经过过滤后没数据 if not result or not result[0]: if not query_filter: print(f"[{datetime.now()}] [pre_check] [NoDataFoundError] In the source database: '{db_name}.{collection}' " "has no data and does not need to migrate any data") else: print( f"[{datetime.now()}] [pre_check] [NoDataFoundError] In the source database: '{db_name}.{collection}' . " f"After filtering under the condition of '{query_filter}', " "there is no data needs to be migrated") exit(0) return result def is_same_address(): return RUN_SCOPES.get('source_connect_url') == RUN_SCOPES.get('target_connect_url') def is_same_db(): return RUN_SCOPES.get('source_db') == RUN_SCOPES.get('target_db') def is_same_collection(): return RUN_SCOPES.get('source_collection') == RUN_SCOPES.get('target_collection') # 参数基本检查 # 源和目的不能完全相同的库和表 def param_check(): print(f"[{datetime.now()}] [Check prams] start.") if is_same_address() and is_same_db() and is_same_collection(): print(f"[{datetime.now()}] [Check prams] [Error] Source and target have the same address, database, " f"and table name. Please modify them to proceed with the migration.") exit(0) print(f"[{datetime.now()}] [Check prams] end.") # 预检查迁移可行性,并将字段映射转换为迁移过程中便于使用的方式 # 1. 检查源、目标均可访问,能获取对应表和元数据 # 2. 源的索引字段类型与目标端的索引字段类型能对应上(如果有field_mapping则根据field_mapping映射,否则同名映射) # 3. 构建传输时的实际映射关系,传输时直接使用,非索引字段schemaless字段也会遵循field_mapping规则,但不检查数据类型 def pre_check() -> dict: print(f'[{datetime.now()}] [field mappings] begin.\n') source_dict = get_source_dict() target_dict = get_target_dict() field_mappings = get_field_mapping_obj() running_mappings = {} simple_mappings = [] for source_key, source_field in source_dict.items(): target_key = field_mappings.get(source_key, source_key) target_field = target_dict.get(target_key, None) # 检查目标端索引字段是否存在 if not target_field: print(f"[{datetime.now()}] [NotFoundIndexError] Cannot find index field '{target_key}'." f"Please verify if the field_mappings settings are correct and " f"if the source and destination index fields are consistent.") # 检查索引字段类型是否一致 if source_field.field_type != target_field.field_type: print(f"[{datetime.now()}] [TypeError] Inconsistent data types.'." f"The type of the source field '{source_key}' is {source_field.field_type}, " f"while the type of the destination field {target_key} is {target_field.field_type}.") if isinstance(source_field, VectorIndex): if source_field.dimension != target_field.dimension: print(f"[{datetime.now()}] [DimensionError] Cannot migrate vectors - " f"source field '{source_field.name}' ({source_field.dimension}) " f"does not match target field '{target_field.name}' ({target_field.dimension}) dimensionality.'.") simple_mappings.append({ 'source_field': source_key, 'target_field': target_key, 'data_type': source_field.field_type.value }) running_mappings[source_key] = { 'target_key': target_key, 'type': 'index_field', 'source_field': source_field, 'target_field': target_field } # 补齐schemaless字段映射 for source_key, target_key in field_mappings.items(): if source_key not in running_mappings: running_mappings[source_key] = { 'target_key': target_key, 'type': 'schemaless' } simple_mappings.append({ 'source_field': source_key, 'target_field': target_key, 'data_type': 'any' }) print(f'[{datetime.now()}] [running field mappings]:\n ' f'{json.dumps(simple_mappings, ensure_ascii=False, indent=2)}') print(f'[{datetime.now()}] [field mappings] end.\n') RUN_SCOPES['running_mappings'] = running_mappings print_simple_line() return target_dict # 获取源端表的结构 def get_source_dict() -> dict: source_client = get_source_client() try: return get_collection_dict( source_client, RUN_SCOPES.get('source_db'), RUN_SCOPES.get('source_collection') ) finally: source_client.close() # 获取目标端表的结构 def get_target_dict() -> dict: target_client = get_target_client() try: return get_collection_dict( get_target_client(), RUN_SCOPES.get('target_db'), RUN_SCOPES.get('target_collection') ) finally: target_client.close() # 获取数据库中的表的结构,转换为dict便于程序使用 def get_collection_dict(client: RPCVectorDBClient, db_name: str, collection: str) -> dict: collection: Collection = client.describe_collection( database_name=db_name, collection_name=collection ) result = {} for index in collection.indexes: result[index.name] = index return result # 将参数中的field mapping转换为dict结构,便于程序使用 def get_field_mapping_obj() -> dict: field_mappings = RUN_SCOPES.get('field_mappings') if not field_mappings or len(str(field_mappings).strip()) == 0: return {} result = {} field_tmp_mappings = str(field_mappings).strip().split(",") for mapping in field_tmp_mappings: map_array = mapping.strip().split("=", 1) if len(map_array) != 2: print(f"[{datetime.now()}] [Error] Incorrect field_mapping setting: {field_mappings}, " f"the correct example is: pk=id,vec=vector") exit(0) result[map_array[0].strip()] = map_array[1].strip() print(f'[{datetime.now()}] [set field mappings]\n {json.dumps(result, ensure_ascii=False, indent=2)}') return result # 生产者从源端读取数据数据 def reader_process(queue, parallel, run_scopes): global RUN_SCOPES RUN_SCOPES = run_scopes.copy() client = get_source_client() db_name = RUN_SCOPES.get("source_db") collection = RUN_SCOPES.get("source_collection") limit = RUN_SCOPES.get("read_page_size") query_filter = run_scopes.get('filter') offset = run_scopes.get('offset') produced = 0 result = client.query( database_name=db_name, collection_name=collection, offset=offset, limit=limit, filter=query_filter, retrieve_vector=True ) while len(result) > 0: produced += len(result) print(f"[{datetime.now()}] [Reader]: queried data size = {len(result)}, produced = {produced}.") queue.put(result) offset += limit result = client.query( database_name=db_name, collection_name=collection, offset=offset, limit=limit, filter=query_filter, retrieve_vector=True ) # 通知各并发结束执行 for i in range(parallel): queue.put(None) def writer_process(p_index, queue: Queue, result_queue: Queue, run_scope): print(f"[{datetime.now()}] [Writer-{p_index}] started.") global RUN_SCOPES RUN_SCOPES = run_scope.copy() db = run_scope.get('target_db') collection = run_scope.get('target_collection') build_index = run_scope.get('write_build_index') write_batch_size = run_scope.get('write_batch_size') written_rows = 0 start_time = time.perf_counter() client = get_target_client() try: while True: result_batch = queue.get() if result_batch: write_count = len(result_batch) written_rows += write_count if write_count <= write_batch_size: write_into_targets(client, db, collection, result_batch, build_index) else: begin = 0 while begin < write_count: end = begin + write_batch_size write_into_targets(client, db, collection, result_batch[begin:end], build_index) begin = end duration = time.perf_counter() - start_time print(f"[{datetime.now()}] [Writer-{p_index}] written {written_rows} rows, duration {duration:.2f}s.") else: break finally: client.close() # writer process end duration = time.perf_counter() - start_time print(f"[{datetime.now()}] [Writer-{p_index}] [completed]. written {written_rows} rows, duration {duration:.2f}s.") result_queue.put({ 'process': p_index, 'duration': duration, 'written_rows': written_rows }) def after_writer(queue_result: Queue, parallel): total_rows = 0 max_duration = 0 rebuild_index_after_write = RUN_SCOPES.get('rebuild_index_after_write') for _ in range(parallel): res = queue_result.get() rows = res.get('written_rows', 0) duration = res.get('duration', 0) if rows > 0: total_rows += rows if duration > max_duration: max_duration = duration print(f"[{datetime.now()}] [Writer-All] All workers finished, written {total_rows} rows, " f"max duration: {max_duration:.2f}s.\n" f"[{datetime.now()}] [Writer-All] Write Throughput: {total_rows / max_duration:.2f} rows/s.") if rebuild_index_after_write: rebuild_index() def rebuild_index(): print_simple_line() print(f"[{datetime.now()}] [Rebuild-index] input param rebuild_index_after_write=true.") print(f'[{datetime.now()}] [Rebuild-index] [start].') client = get_target_client() db = RUN_SCOPES.get('target_db') collection = RUN_SCOPES.get('target_collection') try: client.rebuild_index( database_name=db, collection_name=collection ) wait_rebuilding(client, db, collection) finally: client.close() def wait_rebuilding(client: RPCVectorDBClient, db, collection): print(f'[{datetime.now()}] [Rebuild-index] [running] wait building...') build_index_start = time.perf_counter() time.sleep(10) index_status = client.describe_collection(db, collection).index_status['status'] while index_status not in ['ready', 'failed']: print(f"[{datetime.now()}] [Rebuild-index] [running] collection index status: {index_status}, waiting ready") time.sleep(4) index_status = client.describe_collection(db, collection).index_status['status'] print(f"[{datetime.now()}] [Rebuild-index] [end]. collection index status: {index_status}, " f"duration: {time.perf_counter() - build_index_start:.2f}s.") def start_reader(): print(f"[{datetime.now()}] [Reader]: start reader process.") parallel = RUN_SCOPES.get('parallel') queue = Queue(maxsize=(parallel * 2)) producer = Process(target=reader_process, args=(queue, parallel, RUN_SCOPES,)) producer.start() return parallel, queue def start_writer(parallel, queue): print(f"[{datetime.now()}] [Writer-All]: start writer process.") writer_processes = [] result_queue = Queue(maxsize=parallel) for i in range(parallel): process = Process(target=writer_process, args=(i, queue, result_queue, RUN_SCOPES,)) writer_processes.append(process) process.start() for process in writer_processes: process.join() return result_queue # 迁移主入口 # 1. 启动1个reader源端读取数据,并放入生产队列 # 2. 启动parallel个writer从消费队列中获取,并写入目标端 # 3. 写入完成后,根据参数决定是否rebuild index def migrate_main(): parallel, queue = start_reader() result_queue = start_writer(parallel, queue) after_writer(result_queue, parallel) if __name__ == "__main__": init_args() pre_migrate() migrate_main()