diff --git a/graphgen/bases/base_operator.py b/graphgen/bases/base_operator.py index 30c71271..ad7c3a7c 100644 --- a/graphgen/bases/base_operator.py +++ b/graphgen/bases/base_operator.py @@ -92,8 +92,11 @@ def __call__( is_first = True for res in result: yield pd.DataFrame([res]) - self.store([res], meta_update if is_first else {}) + self.store( + [res], meta_update if is_first else {}, flush=False + ) is_first = False + self.kv_storage.index_done_callback() else: yield pd.DataFrame(result) self.store(result, meta_update) @@ -141,7 +144,7 @@ def split(self, batch: "pd.DataFrame") -> tuple["pd.DataFrame", "pd.DataFrame"]: recovered_chunks = [c for c in recovered_chunks if c is not None] return to_process, pd.DataFrame(recovered_chunks) - def store(self, results: list, meta_update: dict): + def store(self, results: list, meta_update: dict, flush: bool = True): results = convert_to_serializable(results) meta_update = convert_to_serializable(meta_update) @@ -159,7 +162,8 @@ def store(self, results: list, meta_update: dict): for v in v_list: inverse_meta[v] = k self.kv_storage.update({"_meta_inverse": inverse_meta}) - self.kv_storage.index_done_callback() + if flush: + self.kv_storage.index_done_callback() @abstractmethod def process(self, batch: list) -> Tuple[Union[list, Iterable[dict]], dict]: diff --git a/graphgen/storage/kv/rocksdb_storage.py b/graphgen/storage/kv/rocksdb_storage.py index d1361169..9b373791 100644 --- a/graphgen/storage/kv/rocksdb_storage.py +++ b/graphgen/storage/kv/rocksdb_storage.py @@ -1,3 +1,4 @@ +import logging import os from dataclasses import dataclass from typing import Any, Dict, List, Set @@ -8,6 +9,8 @@ from graphgen.bases.base_storage import BaseKVStorage +logger = logging.getLogger(__name__) + @dataclass class RocksDBKVStorage(BaseKVStorage): @@ -17,8 +20,10 @@ class RocksDBKVStorage(BaseKVStorage): def __post_init__(self): self._db_path = os.path.join(self.working_dir, f"{self.namespace}.db") self._db = Rdict(self._db_path) - print( - f"RocksDBKVStorage initialized for namespace '{self.namespace}' at '{self._db_path}'" + logger.debug( + "RocksDBKVStorage initialized for namespace '%s' at '%s'", + self.namespace, + self._db_path, ) @property @@ -30,7 +35,7 @@ def all_keys(self) -> List[str]: def index_done_callback(self): self._db.flush() - print(f"RocksDB flushed for {self.namespace}") + logger.debug("RocksDB flushed for %s", self.namespace) def get_by_id(self, id: str) -> Any: return self._db.get(id, None)