From 55e7049bf3ffe223995b5f1437bce6dcd892604a Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Fri, 6 Oct 2023 19:52:54 -0700 Subject: [PATCH 01/36] refactor(BaseLoader): split read_data and parse_data --- pyTigerGraph/gds/dataloaders.py | 1351 +++++++++++++++++++++---------- tests/test_gds_BaseLoader.py | 60 +- 2 files changed, 925 insertions(+), 486 deletions(-) diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index 5205cc7c..63b15120 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -18,7 +18,7 @@ from threading import Event, Thread from time import sleep import pickle -from typing import TYPE_CHECKING, Any, Iterator, NoReturn, Tuple, Union, Callable +from typing import TYPE_CHECKING, Any, Iterator, NoReturn, Tuple, Union, Callable, List, Dict #import re #RE_SPLITTER = re.compile(r',(?![^\[]*\])') @@ -548,13 +548,12 @@ def _request_kafka( tgraph.abortQuery(resp) @staticmethod - def _request_rest( + def _request_graph_rest( tgraph: "TigerGraphConnection", query_name: str, read_task_q: Queue, timeout: int = 600000, payload: dict = {}, - resp_type: 'Literal["both", "vertex", "edge"]' = "both", ) -> NoReturn: # Run query resp = tgraph.runInstalledQuery( @@ -562,21 +561,30 @@ def _request_rest( ) # Put raw data into reading queue for i in resp: - if resp_type == "both": - data = (i["vertex_batch"], i["edge_batch"]) - elif resp_type == "vertex": - data = i["vertex_batch"] - elif resp_type == "edge": - data = i["edge_batch"] - read_task_q.put(data) + read_task_q.put((i["vertex_batch"], i["edge_batch"])) read_task_q.put(None) @staticmethod - def _download_from_kafka( + def _request_unimode_rest( + tgraph: "TigerGraphConnection", + query_name: str, + read_task_q: Queue, + timeout: int = 600000, + payload: dict = {}, + ) -> NoReturn: + # Run query + resp = tgraph.runInstalledQuery( + query_name, params=payload, timeout=timeout, usePost=True + ) + # Put raw data into reading queue + for i in resp: + read_task_q.put(i["data_batch"]) + read_task_q.put(None) + + @staticmethod + def _download_graph_kafka( exit_event: Event, read_task_q: Queue, - num_batches: int, - out_tuple: bool, kafka_consumer: "KafkaConsumer", max_wait_time: int = 300 ) -> NoReturn: @@ -584,8 +592,6 @@ def _download_from_kafka( buffer = {} wait_time = 0 while (not exit_event.is_set()) and (wait_time < max_wait_time): - if delivered_batch == num_batches: - break resp = kafka_consumer.poll(1000) if not resp: wait_time += 1 @@ -594,40 +600,55 @@ def _download_from_kafka( for msgs in resp.values(): for message in msgs: key = message.key.decode("utf-8") - if out_tuple: - if key.startswith("vertex"): - companion_key = key.replace("vertex", "edge") - if companion_key in buffer: - read_task_q.put((message.value.decode("utf-8"), - buffer[companion_key])) - del buffer[companion_key] - delivered_batch += 1 - else: - buffer[key] = message.value.decode("utf-8") - elif key.startswith("edge"): - companion_key = key.replace("edge", "vertex") - if companion_key in buffer: - read_task_q.put((buffer[companion_key], - message.value.decode("utf-8"))) - del buffer[companion_key] - delivered_batch += 1 - else: - buffer[key] = message.value.decode("utf-8") + if key.startswith("vertex"): + companion_key = key.replace("vertex", "edge") + if companion_key in buffer: + read_task_q.put((message.value.decode("utf-8"), + buffer[companion_key])) + del buffer[companion_key] + delivered_batch += 1 else: - raise ValueError( - "Unrecognized key {} for messages in kafka".format(key) - ) + buffer[key] = 
message.value.decode("utf-8")
+                    elif key.startswith("edge"):
+                        companion_key = key.replace("edge", "vertex")
+                        if companion_key in buffer:
+                            read_task_q.put((buffer[companion_key],
+                                            message.value.decode("utf-8")))
+                            del buffer[companion_key]
+                            delivered_batch += 1
+                        else:
+                            buffer[key] = message.value.decode("utf-8")
                     else:
-                        read_task_q.put(message.value.decode("utf-8"))
-                        delivered_batch += 1
+                        raise ValueError(
+                            "Unrecognized key {} for messages in kafka".format(key)
+                        )
+        read_task_q.put(None)
+
+    @staticmethod
+    def _download_unimode_kafka(
+        exit_event: Event,
+        read_task_q: Queue,
+        kafka_consumer: "KafkaConsumer",
+        max_wait_time: int = 300
+    ) -> NoReturn:
+        delivered_batch = 0
+        wait_time = 0
+        while (not exit_event.is_set()) and (wait_time < max_wait_time):
+            resp = kafka_consumer.poll(1000)
+            if not resp:
+                wait_time += 1
+                continue
+            wait_time = 0
+            for msgs in resp.values():
+                for message in msgs:
+                    read_task_q.put(message.value.decode("utf-8"))
+                    delivered_batch += 1
         read_task_q.put(None)
 
     @staticmethod
-    def _read_data(
+    def _read_graph_data(
         exit_event: Event,
         in_q: Queue,
         out_q: Queue,
-        in_format: str = "vertex",
         out_format: str = "dataframe",
         v_in_feats: Union[list, dict] = [],
         v_out_labels: Union[list, dict] = [],
@@ -641,8 +662,51 @@
         delimiter: str = "|",
         reindex: bool = True,
         is_hetero: bool = False,
-        callback_fn: Callable = None,
+        callback_fn: Callable = None
     ) -> NoReturn:
+        # Import the right libraries based on output format
+        out_format = out_format.lower()
+        if out_format == "pyg" or out_format == "dgl":
+            try:
+                import torch
+            except ImportError:
+                raise ImportError(
+                    "PyTorch is not installed. Please install it to use PyG or DGL output."
+                )
+            if out_format == "dgl":
+                try:
+                    import dgl
+                except ImportError:
+                    raise ImportError(
+                        "DGL is not installed. Please install DGL to use DGL format."
+                    )
+            elif out_format == "pyg":
+                try:
+                    import torch_geometric as pyg
+                except ImportError:
+                    raise ImportError(
+                        "PyG is not installed. Please install PyG to use PyG format."
+                    )
+        elif out_format.lower() == "spektral":
+            try:
+                import tensorflow as tf
+            except ImportError:
+                raise ImportError(
+                    "Tensorflow is not installed. Please install it to use spektral output."
+                )
+            try:
+                import scipy
+            except ImportError:
+                raise ImportError(
+                    "scipy is not installed. Please install it to use spektral output."
+                )
+            try:
+                import spektral
+            except ImportError:
+                raise ImportError(
+                    "Spektral is not installed. Please install it to use spektral output."
+ ) + # Get raw data from queue and parse while not exit_event.is_set(): raw = in_q.get() if raw is None: @@ -650,10 +714,8 @@ def _read_data( out_q.put(None) break try: - data = BaseLoader._parse_data( + data = BaseLoader._parse_graph_data_to_df( raw = raw, - in_format = in_format, - out_format = out_format, v_in_feats = v_in_feats, v_out_labels = v_out_labels, v_extra_feats = v_extra_feats, @@ -662,28 +724,575 @@ def _read_data( e_out_labels = e_out_labels, e_extra_feats = e_extra_feats, e_attr_types = e_attr_types, - add_self_loop = add_self_loop, delimiter = delimiter, - reindex = reindex, primary_id = {}, is_hetero = is_hetero, - callback_fn = callback_fn ) + if out_format == "pyg": + data = BaseLoader._parse_df_to_pyg( + raw = data, + v_in_feats = v_in_feats, + v_out_labels = v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = e_in_feats, + e_out_labels = e_out_labels, + e_extra_feats = e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = add_self_loop, + reindex = reindex, + is_hetero = is_hetero, + torch = torch, + pyg = pyg + ) + elif out_format == "dgl": + data = BaseLoader._parse_df_to_dgl( + raw = data, + v_in_feats = v_in_feats, + v_out_labels = v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = e_in_feats, + e_out_labels = e_out_labels, + e_extra_feats = e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = add_self_loop, + reindex = reindex, + is_hetero = is_hetero, + torch = torch, + dgl= dgl + ) + elif out_format == "spektral" and is_hetero==False: + data = BaseLoader._parse_df_to_spektral( + raw = data, + v_in_feats = v_in_feats, + v_out_labels = v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = e_in_feats, + e_out_labels = e_out_labels, + e_extra_feats = e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = add_self_loop, + reindex = reindex, + is_hetero = is_hetero, + scipy = scipy, + spektral = spektral + ) + elif out_format == "dataframe" or out_format == "df": + pass + else: + raise NotImplementedError + if callback_fn: + data = callback_fn(data) out_q.put(data) except Exception as err: - warnings.warn("Error parsing a data batch. Set logging level to ERROR for details.") + warnings.warn("Error parsing a graph batch. 
Set logging level to ERROR for details.") logger.error(err, exc_info=True) logger.error("Error parsing data: {}".format(raw)) - logger.error("Parameters:\n in_format={}\n out_format={}\n v_in_feats={}\n v_out_labels={}\n v_extra_feats={}\n v_attr_types={}\n e_in_feats={}\n e_out_labels={}\n e_extra_feats={}\n e_attr_types={}\n delimiter={}\n".format( - in_format, out_format, v_in_feats, v_out_labels, v_extra_feats, v_attr_types, e_in_feats, e_out_labels, e_extra_feats, e_attr_types, delimiter)) + logger.error("Parameters:\n out_format={}\n v_in_feats={}\n v_out_labels={}\n v_extra_feats={}\n v_attr_types={}\n e_in_feats={}\n e_out_labels={}\n e_extra_feats={}\n e_attr_types={}\n delimiter={}\n".format( + out_format, v_in_feats, v_out_labels, v_extra_feats, v_attr_types, e_in_feats, e_out_labels, e_extra_feats, e_attr_types, delimiter)) + + @staticmethod + def _read_vertex_data( + exit_event: Event, + in_q: Queue, + out_q: Queue, + v_in_feats: Union[list, dict] = [], + v_out_labels: Union[list, dict] = [], + v_extra_feats: Union[list, dict] = [], + v_attr_types: dict = {}, + delimiter: str = "|", + is_hetero: bool = False, + callback_fn: Callable = None + ) -> NoReturn: + while not exit_event.is_set(): + raw = in_q.get() + if raw is None: + in_q.task_done() + out_q.put(None) + break + try: + data = BaseLoader._parse_vertex_data( + raw = raw, + v_in_feats = v_in_feats, + v_out_labels = v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + delimiter = delimiter, + is_hetero = is_hetero + ) + if callback_fn: + data = callback_fn(data) + out_q.put(data) + except Exception as err: + warnings.warn("Error parsing a vertex batch. Set logging level to ERROR for details.") + logger.error(err, exc_info=True) + logger.error("Error parsing data: {}".format(raw)) + logger.error("Parameters:\n v_in_feats={}\n v_out_labels={}\n v_extra_feats={}\n v_attr_types={}\n delimiter={}\n".format( + v_in_feats, v_out_labels, v_extra_feats, v_attr_types, delimiter)) - in_q.task_done() + @staticmethod + def _read_edge_data( + exit_event: Event, + in_q: Queue, + out_q: Queue, + e_in_feats: Union[list, dict] = [], + e_out_labels: Union[list, dict] = [], + e_extra_feats: Union[list, dict] = [], + e_attr_types: dict = {}, + delimiter: str = "|", + is_hetero: bool = False, + callback_fn: Callable = None + ) -> NoReturn: + while not exit_event.is_set(): + raw = in_q.get() + if raw is None: + in_q.task_done() + out_q.put(None) + break + try: + data = BaseLoader._parse_edge_data( + raw = raw, + e_in_feats = e_in_feats, + e_out_labels = e_out_labels, + e_extra_feats = e_extra_feats, + e_attr_types = e_attr_types, + delimiter = delimiter, + is_hetero = is_hetero + ) + if callback_fn: + data = callback_fn(data) + out_q.put(data) + except Exception as err: + warnings.warn("Error parsing an edge batch. Set logging level to ERROR for details.") + logger.error(err, exc_info=True) + logger.error("Error parsing data: {}".format(raw)) + logger.error("Parameters:\n e_in_feats={}\n e_out_labels={}\n e_extra_feats={}\n e_attr_types={}\n delimiter={}\n".format( + e_in_feats, e_out_labels, e_extra_feats, e_attr_types, delimiter)) + + @staticmethod + def _parse_vertex_data( + raw: List[str], + v_in_feats: Union[list, dict] = [], + v_out_labels: Union[list, dict] = [], + v_extra_feats: Union[list, dict] = [], + v_attr_types: dict = {}, + delimiter: str = "|", + is_hetero: bool = False + ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]: + """Parse raw vertex data into dataframes. 
+ """ + # Read in vertex CSVs as dataframes + if not is_hetero: + # String of vertices in format vid,v_in_feats,v_out_labels,v_extra_feats + v_attributes = ["vid"] + v_in_feats + v_out_labels + v_extra_feats + v_file = (line.strip().split(delimiter) for line in raw.splitlines()) + data = pd.DataFrame(v_file, columns=v_attributes) + for column in data.columns: + data[column] = pd.to_numeric(data[column], errors="ignore") + for v_attr in v_attributes: + if v_attr_types.get(v_attr, "") == "MAP": + # I am sorry that this is this ugly... + data[v_attr] = data[v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) + else: + # String of vertices in format vtype,vid,v_in_feats,v_out_labels,v_extra_feats + v_file = (line.strip().split(delimiter) for line in raw.splitlines()) + v_file_dict = defaultdict(list) + for line in v_file: + v_file_dict[line[0]].append(line[1:]) + vertices = {} + for vtype in v_file_dict: + v_attributes = ["vid"] + \ + v_in_feats.get(vtype, []) + \ + v_out_labels.get(vtype, []) + \ + v_extra_feats.get(vtype, []) + vertices[vtype] = pd.DataFrame(v_file_dict[vtype], columns=v_attributes) + for v_attr in v_extra_feats.get(vtype, []): + if v_attr_types[vtype][v_attr] == "MAP": + # I am sorry that this is this ugly... + vertices[vtype][v_attr] = vertices[vtype][v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) + data = vertices + return data + + @staticmethod + def _parse_edge_data( + raw: List[str], + e_in_feats: Union[list, dict] = [], + e_out_labels: Union[list, dict] = [], + e_extra_feats: Union[list, dict] = [], + e_attr_types: dict = {}, + delimiter: str = "|", + is_hetero: bool = False + ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]: + """Parse raw edge data into dataframes. + """ + # Read in edge CSVs as dataframes + if not is_hetero: + # String of edges in format source_vid,target_vid,... + e_attributes = ["source", "target"] + e_in_feats + e_out_labels + e_extra_feats + #file = "\n".join(x for x in raw.split("\n") if x.strip()) + #data = pd.read_table(io.StringIO(file), header=None, names=e_attributes, sep=delimiter) + e_file = (line.strip().split(delimiter) for line in raw.splitlines()) + data = pd.DataFrame(e_file, columns=e_attributes) + for column in data.columns: + data[column] = pd.to_numeric(data[column], errors="ignore") + for e_attr in e_attributes: + if e_attr_types.get(e_attr, "") == "MAP": + # I am sorry that this is this ugly... + data[e_attr] = data[e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) + else: + # String of edges in format etype,source_vid,target_vid,... + e_file = (line.strip().split(delimiter) for line in raw.splitlines()) + e_file_dict = defaultdict(list) + for line in e_file: + e_file_dict[line[0]].append(line[1:]) + edges = {} + for etype in e_file_dict: + e_attributes = ["source", "target"] + \ + e_in_feats.get(etype, []) + \ + e_out_labels.get(etype, []) + \ + e_extra_feats.get(etype, []) + edges[etype] = pd.DataFrame(e_file_dict[etype], columns=e_attributes) + for e_attr in e_extra_feats.get(etype, []): + if e_attr_types[etype][e_attr] == "MAP": + # I am sorry that this is this ugly... 
+ edges[etype][e_attr] = edges[etype][e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) + del e_file_dict, e_file + data = edges + + return data + + @staticmethod + def _parse_graph_data_to_df( + raw: Tuple[List[str], List[str]], + v_in_feats: Union[list, dict] = [], + v_out_labels: Union[list, dict] = [], + v_extra_feats: Union[list, dict] = [], + v_attr_types: dict = {}, + e_in_feats: Union[list, dict] = [], + e_out_labels: Union[list, dict] = [], + e_extra_feats: Union[list, dict] = [], + e_attr_types: dict = {}, + delimiter: str = "|", + primary_id: dict = {}, + is_hetero: bool = False + ) -> Union[Tuple[pd.DataFrame, pd.DataFrame], Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]]: + """Parse raw data into dataframes. + """ + # Read in vertex and edge CSVs as dataframes + # A pair of in-memory CSVs (vertex, edge) + v_file, e_file = raw + if not is_hetero: + v_attributes = ["vid"] + v_in_feats + v_out_labels + v_extra_feats + e_attributes = ["source", "target"] + e_in_feats + e_out_labels + e_extra_feats + v_file = (line.split(delimiter) for line in v_file.splitlines()) + vertices = pd.DataFrame(v_file, columns=v_attributes) + for column in vertices.columns: + vertices[column] = pd.to_numeric(vertices[column], errors="ignore") + for v_attr in v_extra_feats: + if v_attr_types[v_attr] == "MAP": + # I am sorry that this is this ugly... + vertices[v_attr] = vertices[v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) + if primary_id: + id_map = pd.DataFrame({"vid": primary_id.keys(), "primary_id": primary_id.values()}) + vertices = vertices.merge(id_map.astype({"vid": vertices["vid"].dtype}), on="vid") + v_extra_feats.append("primary_id") + e_file = (line.split(delimiter) for line in e_file.splitlines()) + edges = pd.DataFrame(e_file, columns=e_attributes) + for column in edges.columns: + edges[column] = pd.to_numeric(edges[column], errors="ignore") + for e_attr in e_attributes: + if e_attr_types.get(e_attr, "") == "MAP": + # I am sorry that this is this ugly... + edges[e_attr] = edges[e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) + else: + v_file = (line.split(delimiter) for line in v_file.splitlines()) + v_file_dict = defaultdict(list) + for line in v_file: + v_file_dict[line[0]].append(line[1:]) + vertices = {} + for vtype in v_file_dict: + v_attributes = ["vid"] + \ + v_in_feats.get(vtype, []) + \ + v_out_labels.get(vtype, []) + \ + v_extra_feats.get(vtype, []) + vertices[vtype] = pd.DataFrame(v_file_dict[vtype], columns=v_attributes, dtype="object") + for v_attr in v_extra_feats.get(vtype, []): + if v_attr_types[vtype][v_attr] == "MAP": + # I am sorry that this is this ugly... 
+ vertices[vtype][v_attr] = vertices[vtype][v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) + if primary_id: + id_map = pd.DataFrame({"vid": primary_id.keys(), "primary_id": primary_id.values()}, + dtype="object") + for vtype in vertices: + vertices[vtype] = vertices[vtype].merge(id_map, on="vid") + e_file = (line.split(delimiter) for line in e_file.splitlines()) + e_file_dict = defaultdict(list) + for line in e_file: + e_file_dict[line[0]].append(line[1:]) + edges = {} + for etype in e_file_dict: + e_attributes = ["source", "target"] + \ + e_in_feats.get(etype, []) + \ + e_out_labels.get(etype, []) + \ + e_extra_feats.get(etype, []) + edges[etype] = pd.DataFrame(e_file_dict[etype], columns=e_attributes, dtype="object") + for e_attr in e_extra_feats.get(etype, []): + if e_attr_types[etype][e_attr] == "MAP": + # I am sorry that this is this ugly... + edges[etype][e_attr] = edges[etype][e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) + return (vertices, edges) + + @staticmethod + def _parse_df_to_pyg( + raw: Union[Tuple[pd.DataFrame, pd.DataFrame], Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]], + v_in_feats: Union[list, dict] = [], + v_out_labels: Union[list, dict] = [], + v_extra_feats: Union[list, dict] = [], + v_attr_types: dict = {}, + e_in_feats: Union[list, dict] = [], + e_out_labels: Union[list, dict] = [], + e_extra_feats: Union[list, dict] = [], + e_attr_types: dict = {}, + add_self_loop: bool = False, + reindex: bool = True, + is_hetero: bool = False, + torch = None, + pyg = None + ) -> Union["pyg.data.Data", "pyg.data.HeteroData"]: + """Parse dataframes to PyG graphs. + """ + def attr_to_tensor( + attributes: list, attr_types: dict, df: pd.DataFrame + ) -> "torch.Tensor": + """Turn multiple columns of a dataframe into a tensor. + """ + x = [] + for col in attributes: + dtype = attr_types[col].lower() + if dtype.startswith("str"): + raise TypeError( + "String type not allowed for input and output features." + ) + if dtype.startswith("list"): + dtype2 = dtype.split(":")[1] + x.append(df[col].str.split(expand=True).to_numpy().astype(dtype2)) + elif dtype.startswith("set") or dtype.startswith("map") or dtype.startswith("date"): + raise NotImplementedError( + "{} type not supported for input and output features yet.".format(dtype)) + elif dtype == "bool": + x.append(df[[col]].astype("int8").to_numpy().astype(dtype)) + elif dtype == "uint": + # PyTorch only supports uint8. Need to convert it to int. + x.append(df[[col]].to_numpy().astype("int")) + else: + x.append(df[[col]].to_numpy().astype(dtype)) + return torch.tensor(np.hstack(x)).squeeze(dim=1) + + def add_attributes(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, + graph, is_hetero: bool, feat_name: str, + target: 'Literal["edge", "vertex"]', vetype: str = None) -> None: + """Add multiple attributes as a single feature to edges or vertices. 
+ """ + if is_hetero: + if not vetype: + raise ValueError("Vertex or edge type required for heterogeneous graphs") + # Focus on a specific type + if target == "edge": + data = graph[attr_types["FromVertexTypeName"], + vetype, + attr_types["ToVertexTypeName"]] + elif target == "vertex": + data = graph[vetype] + else: + data = graph + data[feat_name] = attr_to_tensor(attr_names, attr_types, attr_df) + + def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, + graph, is_hetero: bool, + target: 'Literal["edge", "vertex"]', vetype: str = None) -> None: + """Add each attribute as a single feature to edges or vertices. + """ + if is_hetero: + if not vetype: + raise ValueError("Vertex or edge type required for heterogeneous graphs") + # Focus on a specific type + if target == "edge": + data = graph[attr_types["FromVertexTypeName"], + vetype, + attr_types["ToVertexTypeName"]] + elif target == "vertex": + data = graph[vetype] + else: + data = graph + + for col in attr_names: + dtype = attr_types[col].lower() + if dtype.startswith("str") or dtype.startswith("map"): + data[col] = attr_df[col].to_list() + elif dtype.startswith("list"): + dtype2 = dtype.split(":")[1] + if dtype2.startswith("str"): + data[col] = attr_df[col].str.split().to_list() + else: + data[col] = torch.tensor( + attr_df[col] + .str.split(expand=True) + .to_numpy() + .astype(dtype2) + ) + elif dtype.startswith("set") or dtype.startswith("date"): + raise NotImplementedError( + "{} type not supported for extra features yet.".format(dtype)) + elif dtype == "bool": + data[col] = torch.tensor( + attr_df[col].astype("int8").astype(dtype) + ) + elif dtype == "uint": + # PyTorch only supports uint8. Need to convert it to int. + data[col] = torch.tensor( + attr_df[col].astype("int") + ) + else: + data[col] = torch.tensor( + attr_df[col].astype(dtype) + ) + + # Convert dataframes into PyG graphs + # Reformat as a graph. + # Need to have a pair of tables for edges and vertices. 
+ vertices, edges = raw + if not is_hetero: + # Deal with edgelist first + if reindex: + vertices["tmp_id"] = range(len(vertices)) + id_map = vertices[["vid", "tmp_id"]] + edges = edges.merge(id_map, left_on="source", right_on="vid") + edges.drop(columns=["source", "vid"], inplace=True) + edges = edges.merge(id_map, left_on="target", right_on="vid") + edges.drop(columns=["target", "vid"], inplace=True) + edgelist = edges[["tmp_id_x", "tmp_id_y"]] + else: + edgelist = edges[["source", "target"]] + + edgelist = torch.tensor(edgelist.to_numpy().T, dtype=torch.long) + if add_self_loop: + edgelist = pyg.utils.add_self_loops(edgelist)[0] + data = pyg.data.Data() + data["edge_index"] = edgelist + # Deal with edge attributes + if e_in_feats: + add_attributes(e_in_feats, e_attr_types, edges, + data, is_hetero, "edge_feat", "edge") + if e_out_labels: + add_attributes(e_out_labels, e_attr_types, edges, + data, is_hetero, "edge_label", "edge") + if e_extra_feats: + add_sep_attr(e_extra_feats, e_attr_types, edges, + data, is_hetero, "edge") + # Deal with vertex attributes next + if v_in_feats: + add_attributes(v_in_feats, v_attr_types, vertices, + data, is_hetero, "x", "vertex") + if v_out_labels: + add_attributes(v_out_labels, v_attr_types, vertices, + data, is_hetero, "y", "vertex") + if v_extra_feats: + add_sep_attr(v_extra_feats, v_attr_types, vertices, + data, is_hetero, "vertex") + else: + # Heterogeneous graph + # Deal with edgelist first + edgelist = {} + if reindex: + id_map = {} + for vtype in vertices: + vertices[vtype]["tmp_id"] = range(len(vertices[vtype])) + id_map[vtype] = vertices[vtype][["vid", "tmp_id"]] + for etype in edges: + source_type = e_attr_types[etype]["FromVertexTypeName"] + target_type = e_attr_types[etype]["ToVertexTypeName"] + if e_attr_types[etype]["IsDirected"] or source_type==target_type: + edges[etype] = edges[etype].merge(id_map[source_type], left_on="source", right_on="vid") + edges[etype].drop(columns=["source", "vid"], inplace=True) + edges[etype] = edges[etype].merge(id_map[target_type], left_on="target", right_on="vid") + edges[etype].drop(columns=["target", "vid"], inplace=True) + edgelist[etype] = edges[etype][["tmp_id_x", "tmp_id_y"]] + else: + subdf1 = edges[etype].merge(id_map[source_type], left_on="source", right_on="vid") + subdf1.drop(columns=["source", "vid"], inplace=True) + subdf1 = subdf1.merge(id_map[target_type], left_on="target", right_on="vid") + subdf1.drop(columns=["target", "vid"], inplace=True) + if len(subdf1) < len(edges[etype]): + subdf2 = edges[etype].merge(id_map[source_type], left_on="target", right_on="vid") + subdf2.drop(columns=["target", "vid"], inplace=True) + subdf2 = subdf2.merge(id_map[target_type], left_on="source", right_on="vid") + subdf2.drop(columns=["source", "vid"], inplace=True) + subdf1 = pd.concat((subdf1, subdf2), ignore_index=True) + edges[etype] = subdf1 + edgelist[etype] = edges[etype][["tmp_id_x", "tmp_id_y"]] + else: + for etype in edges: + edgelist[etype] = edges[etype][["source", "target"]] + for etype in edges: + edgelist[etype] = torch.tensor(edgelist[etype].to_numpy().T, dtype=torch.long) + if add_self_loop: + edgelist[etype] = pyg.utils.add_self_loops(edgelist[etype])[0] + data = pyg.data.HeteroData() + for etype in edgelist: + data[e_attr_types[etype]["FromVertexTypeName"], + etype, + e_attr_types[etype]["ToVertexTypeName"]].edge_index = edgelist[etype] + # Deal with edge attributes + if e_in_feats: + for etype in edges: + if etype not in e_in_feats: + continue + if e_in_feats[etype]: + 
add_attributes(e_in_feats[etype], e_attr_types[etype], edges[etype],
+                            data, is_hetero, "edge_feat", "edge", etype)
+            if e_out_labels:
+                for etype in edges:
+                    if etype not in e_out_labels:
+                        continue
+                    if e_out_labels[etype]:
+                        add_attributes(e_out_labels[etype], e_attr_types[etype], edges[etype],
+                            data, is_hetero, "edge_label", "edge", etype)
+            if e_extra_feats:
+                for etype in edges:
+                    if etype not in e_extra_feats:
+                        continue
+                    if e_extra_feats[etype]:
+                        add_sep_attr(e_extra_feats[etype], e_attr_types[etype], edges[etype],
+                            data, is_hetero, "edge", etype)
+            # Deal with vertex attributes next
+            if v_in_feats:
+                for vtype in vertices:
+                    if vtype not in v_in_feats:
+                        continue
+                    if v_in_feats[vtype]:
+                        add_attributes(v_in_feats[vtype], v_attr_types[vtype], vertices[vtype],
+                            data, is_hetero, "x", "vertex", vtype)
+            if v_out_labels:
+                for vtype in vertices:
+                    if vtype not in v_out_labels:
+                        continue
+                    if v_out_labels[vtype]:
+                        add_attributes(v_out_labels[vtype], v_attr_types[vtype], vertices[vtype],
+                            data, is_hetero, "y", "vertex", vtype)
+            if v_extra_feats:
+                for vtype in vertices:
+                    if vtype not in v_extra_feats:
+                        continue
+                    if v_extra_feats[vtype]:
+                        add_sep_attr(v_extra_feats[vtype], v_attr_types[vtype], vertices[vtype],
+                            data, is_hetero, "vertex", vtype)
+        return data
 
     @staticmethod
-    def _parse_data(
-        raw: Union[str, Tuple[str, str]],
-        in_format: 'Literal["vertex", "edge", "graph"]' = "vertex",
-        out_format: str = "dataframe",
+    def _parse_df_to_dgl(
+        raw: Union[Tuple[pd.DataFrame, pd.DataFrame], Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]],
         v_in_feats: Union[list, dict] = [],
         v_out_labels: Union[list, dict] = [],
         v_extra_feats: Union[list, dict] = [],
@@ -693,14 +1302,12 @@
         e_extra_feats: Union[list, dict] = [],
         e_attr_types: dict = {},
         add_self_loop: bool = False,
-        delimiter: str = "|",
         reindex: bool = True,
-        primary_id: dict = {},
         is_hetero: bool = False,
-        callback_fn: Callable = None,
-    ) -> Union[pd.DataFrame, Tuple[pd.DataFrame, pd.DataFrame], "dgl.DGLGraph", "pyg.data.Data", "spektral.data.graph.Graph",
-        dict, Tuple[dict, dict], "pyg.data.HeteroData"]:
-        """Parse raw data into dataframes, DGL graphs, or PyG graphs.
+        torch = None,
+        dgl = None
+    ) -> Union["dgl.graph", "dgl.heterograph"]:
+        """Parse dataframes to DGL graphs.
         """
         def attr_to_tensor(
             attributes: list, attr_types: dict, df: pd.DataFrame
@@ -727,16 +1334,10 @@ def attr_to_tensor(
                     x.append(df[[col]].to_numpy().astype("int"))
                 else:
                     x.append(df[[col]].to_numpy().astype(dtype))
-            if mode == "pyg" or mode == "dgl":
-                return torch.tensor(np.hstack(x)).squeeze(dim=1)
-            elif mode == "spektral":
-                try:
-                    return np.squeeze(np.hstack(x), axis=1) #throws an error if axis isn't 1
-                except:
-                    return np.hstack(x)
+            return torch.tensor(np.hstack(x)).squeeze(dim=1)
 
         def add_attributes(attr_names: list, attr_types: dict, attr_df: pd.DataFrame,
-                           graph, is_hetero: bool, mode: str, feat_name: str,
+                           graph, is_hetero: bool, feat_name: str,
                            target: 'Literal["edge", "vertex"]', vetype: str = None) -> None:
             """Add multiple attributes as a single feature to edges or vertices.
""" @@ -744,31 +1345,20 @@ def add_attributes(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, if not vetype: raise ValueError("Vertex or edge type required for heterogeneous graphs") # Focus on a specific type - if mode == "pyg": - if target == "edge": - data = graph[attr_types["FromVertexTypeName"], - vetype, - attr_types["ToVertexTypeName"]] - elif target == "vertex": - data = graph[vetype] - elif mode == "dgl": - if target == "edge": - data = graph.edges[vetype].data - elif target == "vertex": - data = graph.nodes[vetype].data + if target == "edge": + data = graph.edges[vetype].data + elif target == "vertex": + data = graph.nodes[vetype].data else: - if mode == "pyg" or mode == "spektral": - data = graph - elif mode == "dgl": - if target == "edge": - data = graph.edata - elif target == "vertex": - data = graph.ndata + if target == "edge": + data = graph.edata + elif target == "vertex": + data = graph.ndata data[feat_name] = attr_to_tensor(attr_names, attr_types, attr_df) def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, - graph, is_hetero: bool, mode: str, + graph, is_hetero: bool, target: 'Literal["edge", "vertex"]', vetype: str = None) -> None: """Add each attribute as a single feature to edges or vertices. """ @@ -776,293 +1366,69 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, if not vetype: raise ValueError("Vertex or edge type required for heterogeneous graphs") # Focus on a specific type - if mode == "pyg": - if target == "edge": - data = graph[attr_types["FromVertexTypeName"], - vetype, - attr_types["ToVertexTypeName"]] - elif target == "vertex": - data = graph[vetype] - elif mode == "dgl": - if target == "edge": - data = graph.edges[vetype].data - elif target == "vertex": - data = graph.nodes[vetype].data + if target == "edge": + data = graph.edges[vetype].data + elif target == "vertex": + data = graph.nodes[vetype].data else: - if mode == "pyg" or mode == "spektral": - data = graph - elif mode == "dgl": - if target == "edge": - data = graph.edata - elif target == "vertex": - data = graph.ndata + if target == "edge": + data = graph.edata + elif target == "vertex": + data = graph.ndata for col in attr_names: dtype = attr_types[col].lower() if dtype.startswith("str") or dtype.startswith("map"): - if mode == "dgl": + if vetype is None: + # Homogeneous graph, add column directly to extra data + graph.extra_data[col] = attr_df[col].to_list() + elif vetype not in graph.extra_data: + # Hetero graph, vetype doesn't exist in extra data + graph.extra_data[vetype] = {} + graph.extra_data[vetype][col] = attr_df[col].to_list() + else: + # Hetero graph and vetype already exists + graph.extra_data[vetype][col] = attr_df[col].to_list() + elif dtype.startswith("list"): + dtype2 = dtype.split(":")[1] + if dtype2.startswith("str"): if vetype is None: # Homogeneous graph, add column directly to extra data - graph.extra_data[col] = attr_df[col].to_list() + graph.extra_data[col] = attr_df[col].str.split().to_list() elif vetype not in graph.extra_data: # Hetero graph, vetype doesn't exist in extra data graph.extra_data[vetype] = {} - graph.extra_data[vetype][col] = attr_df[col].to_list() + graph.extra_data[vetype][col] = attr_df[col].str.split().to_list() else: # Hetero graph and vetype already exists - graph.extra_data[vetype][col] = attr_df[col].to_list() - elif mode == "pyg" or mode == "spektral": - data[col] = attr_df[col].to_list() - elif dtype.startswith("list"): - dtype2 = dtype.split(":")[1] - if dtype2.startswith("str"): - 
if mode == "dgl": - if vetype is None: - # Homogeneous graph, add column directly to extra data - graph.extra_data[col] = attr_df[col].str.split().to_list() - elif vetype not in graph.extra_data: - # Hetero graph, vetype doesn't exist in extra data - graph.extra_data[vetype] = {} - graph.extra_data[vetype][col] = attr_df[col].str.split().to_list() - else: - # Hetero graph and vetype already exists - graph.extra_data[vetype][col] = attr_df[col].str.split().to_list() - elif mode == "pyg" or mode == "spektral": - data[col] = attr_df[col].str.split().to_list() + graph.extra_data[vetype][col] = attr_df[col].str.split().to_list() else: - if mode == "pyg" or mode == "dgl": - data[col] = torch.tensor( - attr_df[col] - .str.split(expand=True) - .to_numpy() - .astype(dtype2) - ) - elif mode == "spektral": - data[col] = attr_df[col].str.split(expand=True).to_numpy().astype(dtype2) + data[col] = torch.tensor( + attr_df[col] + .str.split(expand=True) + .to_numpy() + .astype(dtype2) + ) elif dtype.startswith("set") or dtype.startswith("date"): raise NotImplementedError( "{} type not supported for extra features yet.".format(dtype)) elif dtype == "bool": - if mode == "pyg" or mode == "dgl": - data[col] = torch.tensor( - attr_df[col].astype("int8").astype(dtype) - ) - elif mode == "spektral": - data[col] = attr_df[col].astype("int8").astype(dtype) + data[col] = torch.tensor( + attr_df[col].astype("int8").astype(dtype) + ) elif dtype == "uint": # PyTorch only supports uint8. Need to convert it to int. - if mode == "pyg" or mode == "dgl": - data[col] = torch.tensor( - attr_df[col].astype("int") - ) - elif mode == "spektral": - data[col] = attr_df[col].astype(dtype) - else: - if mode == "pyg" or mode == "dgl": - data[col] = torch.tensor( - attr_df[col].astype(dtype) - ) - elif mode == "spektral": - data[col] = attr_df[col].astype(dtype) - - # Read in vertex and edge CSVs as dataframes - vertices, edges = None, None - if in_format == "vertex": - # String of vertices in format vid,v_in_feats,v_out_labels,v_extra_feats - if not is_hetero: - v_attributes = ["vid"] + v_in_feats + v_out_labels + v_extra_feats - v_file = (line.split(delimiter) for line in raw.split('\n') if line) - data = pd.DataFrame(v_file, columns=v_attributes) - for column in data.columns: - data[column] = pd.to_numeric(data[column], errors="ignore") - for v_attr in v_attributes: - if v_attr_types.get(v_attr, "") == "MAP": - # I am sorry that this is this ugly... - data[v_attr] = data[v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) - else: - v_file = (line.split(delimiter) for line in raw.split('\n') if line) - v_file_dict = defaultdict(list) - for line in v_file: - v_file_dict[line[0]].append(line[1:]) - vertices = {} - for vtype in v_file_dict: - v_attributes = ["vid"] + \ - v_in_feats.get(vtype, []) + \ - v_out_labels.get(vtype, []) + \ - v_extra_feats.get(vtype, []) - vertices[vtype] = pd.DataFrame(v_file_dict[vtype], columns=v_attributes) - for v_attr in v_extra_feats.get(vtype, []): - if v_attr_types[vtype][v_attr] == "MAP": - # I am sorry that this is this ugly... 
- vertices[vtype][v_attr] = vertices[vtype][v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) - data = vertices - elif in_format == "edge": - # String of edges in format source_vid,target_vid - if not is_hetero: - e_attributes = ["source", "target"] + e_in_feats + e_out_labels + e_extra_feats - #file = "\n".join(x for x in raw.split("\n") if x.strip()) - #data = pd.read_table(io.StringIO(file), header=None, names=e_attributes, sep=delimiter) - e_file = (line.split(delimiter) for line in raw.split('\n') if line) - data = pd.DataFrame(e_file, columns=e_attributes) - for column in data.columns: - data[column] = pd.to_numeric(data[column], errors="ignore") - for e_attr in e_attributes: - if e_attr_types.get(e_attr, "") == "MAP": - # I am sorry that this is this ugly... - data[e_attr] = data[e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) - else: - e_file = (line.split(delimiter) for line in raw.split('\n') if line) - e_file_dict = defaultdict(list) - for line in e_file: - e_file_dict[line[0]].append(line[1:]) - edges = {} - for etype in e_file_dict: - e_attributes = ["source", "target"] + \ - e_in_feats.get(etype, []) + \ - e_out_labels.get(etype, []) + \ - e_extra_feats.get(etype, []) - edges[etype] = pd.DataFrame(e_file_dict[etype], columns=e_attributes) - for e_attr in e_extra_feats.get(etype, []): - if e_attr_types[etype][e_attr] == "MAP": - # I am sorry that this is this ugly... - edges[etype][e_attr] = edges[etype][e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) - del e_file_dict, e_file - data = edges - elif in_format == "graph": - # A pair of in-memory CSVs (vertex, edge) - v_file, e_file = raw - if not is_hetero: - v_attributes = ["vid"] + v_in_feats + v_out_labels + v_extra_feats - e_attributes = ["source", "target"] + e_in_feats + e_out_labels + e_extra_feats - #file = "\n".join(x for x in v_file.split("\n") if x.strip()) - v_file = (line.split(delimiter) for line in v_file.split('\n') if line) - vertices = pd.DataFrame(v_file, columns=v_attributes) - for column in vertices.columns: - vertices[column] = pd.to_numeric(vertices[column], errors="ignore") - for v_attr in v_extra_feats: - if v_attr_types[v_attr] == "MAP": - # I am sorry that this is this ugly... - vertices[v_attr] = vertices[v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) - if primary_id: - id_map = pd.DataFrame({"vid": primary_id.keys(), "primary_id": primary_id.values()}) - vertices = vertices.merge(id_map.astype({"vid": vertices["vid"].dtype}), on="vid") - v_extra_feats.append("primary_id") - #file = "\n".join(x for x in e_file.split("\n") if x.strip()) - e_file = (line.split(delimiter) for line in e_file.split('\n') if line) - #edges = pd.read_table(io.StringIO(file), header=None, names=e_attributes, dtype="object", sep=delimiter) - edges = pd.DataFrame(e_file, columns=e_attributes) - for column in edges.columns: - edges[column] = pd.to_numeric(edges[column], errors="ignore") - for e_attr in e_attributes: - if e_attr_types.get(e_attr, "") == "MAP": - # I am sorry that this is this ugly... 
- edges[e_attr] = edges[e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) - else: - v_file = (line.split(delimiter) for line in v_file.split('\n') if line) - v_file_dict = defaultdict(list) - for line in v_file: - v_file_dict[line[0]].append(line[1:]) - vertices = {} - for vtype in v_file_dict: - v_attributes = ["vid"] + \ - v_in_feats.get(vtype, []) + \ - v_out_labels.get(vtype, []) + \ - v_extra_feats.get(vtype, []) - vertices[vtype] = pd.DataFrame(v_file_dict[vtype], columns=v_attributes, dtype="object") - for v_attr in v_extra_feats.get(vtype, []): - if v_attr_types[vtype][v_attr] == "MAP": - # I am sorry that this is this ugly... - vertices[vtype][v_attr] = vertices[vtype][v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) - if primary_id: - id_map = pd.DataFrame({"vid": primary_id.keys(), "primary_id": primary_id.values()}, - dtype="object") - for vtype in vertices: - vertices[vtype] = vertices[vtype].merge(id_map, on="vid") - v_extra_feats[vtype].append("primary_id") - del v_file_dict, v_file - e_file = (line.split(delimiter) for line in e_file.split('\n') if line) - e_file_dict = defaultdict(list) - for line in e_file: - e_file_dict[line[0]].append(line[1:]) - edges = {} - for etype in e_file_dict: - e_attributes = ["source", "target"] + \ - e_in_feats.get(etype, []) + \ - e_out_labels.get(etype, []) + \ - e_extra_feats.get(etype, []) - edges[etype] = pd.DataFrame(e_file_dict[etype], columns=e_attributes, dtype="object") - for e_attr in e_extra_feats.get(etype, []): - if e_attr_types[etype][e_attr] == "MAP": - # I am sorry that this is this ugly... - edges[etype][e_attr] = edges[etype][e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) - del e_file_dict, e_file - data = (vertices, edges) - else: - raise NotImplementedError - # Convert dataframes into PyG or DGL graphs - if out_format.lower() == "pyg" or out_format.lower() == "dgl": - if vertices is None or edges is None: - raise ValueError( - "Spektral, PyG, or DGL format can only be used with (sub)graph loaders." - ) - try: - import torch - except ImportError: - raise ImportError( - "PyTorch is not installed. Please install it to use PyG or DGL output." - ) - if out_format.lower() == "dgl": - try: - import dgl - mode = "dgl" - except ImportError: - raise ImportError( - "DGL is not installed. Please install DGL to use DGL format." + data[col] = torch.tensor( + attr_df[col].astype("int") ) - elif out_format.lower() == "pyg": - try: - from torch_geometric.data import Data as pygData - from torch_geometric.data import \ - HeteroData as pygHeteroData - from torch_geometric.utils import add_self_loops - mode = "pyg" - except ImportError: - raise ImportError( - "PyG is not installed. Please install PyG to use PyG format." + else: + data[col] = torch.tensor( + attr_df[col].astype(dtype) ) - elif out_format.lower() == "spektral": - if vertices is None or edges is None: - raise ValueError( - "Spektral, PyG, or DGL format can only be used with (sub)graph loaders." - ) - try: - import tensorflow as tf - except ImportError: - raise ImportError( - "Tensorflow is not installed. Please install it to use spektral output." - ) - try: - import scipy - except ImportError: - raise ImportError( - "scipy is not installed. 
Please install it to use spektral output." - ) - try: - import spektral - mode = "spektral" - except ImportError: - raise ImportError( - "Spektral is not installed. Please install it to use spektral output." - ) - elif out_format.lower() == "dataframe": - if callback_fn: - return callback_fn(data) - else: - return data - else: - raise NotImplementedError + # Reformat as a graph. # Need to have a pair of tables for edges and vertices. + vertices, edges = raw if not is_hetero: # Deal with edgelist first if reindex: @@ -1076,111 +1442,75 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, else: edgelist = edges[["source", "target"]] - if mode == "dgl" or mode == "pyg": - edgelist = torch.tensor(edgelist.to_numpy().T, dtype=torch.long) - if mode == "dgl": - data = dgl.graph(data=(edgelist[0], edgelist[1])) - if add_self_loop: - data = dgl.add_self_loop(data) - data.extra_data = {} - elif mode == "pyg": - data = pygData() - if add_self_loop: - edgelist = add_self_loops(edgelist)[0] - data["edge_index"] = edgelist - elif mode == "spektral": - n_edges = len(edgelist) - n_vertices = len(vertices) - adjacency_data = [1 for i in range(n_edges)] #spektral adjacency format requires weights for each edge to initialize - adjacency = scipy.sparse.coo_matrix((adjacency_data, (edgelist["tmp_id_x"], edgelist["tmp_id_y"])), shape=(n_vertices, n_vertices)) - if add_self_loop: - adjacency = spektral.utils.add_self_loops(adjacency, value=1) - edge_index = np.stack((adjacency.row, adjacency.col), axis=-1) - data = spektral.data.graph.Graph(A=adjacency) - del edgelist + edgelist = torch.tensor(edgelist.to_numpy().T, dtype=torch.long) + data = dgl.graph(data=(edgelist[0], edgelist[1])) + if add_self_loop: + data = dgl.add_self_loop(data) + data.extra_data = {} # Deal with edge attributes if e_in_feats: add_attributes(e_in_feats, e_attr_types, edges, - data, is_hetero, mode, "edge_feat", "edge") - if mode == "spektral": - edge_data = data["edge_feat"] - edge_index, edge_data = spektral.utils.reorder(edge_index, edge_features=edge_data) - n_edges = len(edge_index) - data["e"] = np.array([[i] for i in edge_data]) #if something breaks when you add self-loops it's here - adjacency_data = [1 for i in range(n_edges)] - data["a"] = scipy.sparse.coo_matrix((adjacency_data, (edge_index[:, 0], edge_index[:, 1])), shape=(n_vertices, n_vertices)) - + data, is_hetero, "edge_feat", "edge") if e_out_labels: add_attributes(e_out_labels, e_attr_types, edges, - data, is_hetero, mode, "edge_label", "edge") + data, is_hetero, "edge_label", "edge") if e_extra_feats: add_sep_attr(e_extra_feats, e_attr_types, edges, - data, is_hetero, mode, "edge") + data, is_hetero, "edge") del edges # Deal with vertex attributes next if v_in_feats: add_attributes(v_in_feats, v_attr_types, vertices, - data, is_hetero, mode, "x", "vertex") + data, is_hetero, "x", "vertex") if v_out_labels: add_attributes(v_out_labels, v_attr_types, vertices, - data, is_hetero, mode, "y", "vertex") + data, is_hetero, "y", "vertex") if v_extra_feats: add_sep_attr(v_extra_feats, v_attr_types, vertices, - data, is_hetero, mode, "vertex") + data, is_hetero, "vertex") del vertices else: # Heterogeneous graph # Deal with edgelist first edgelist = {} if reindex: - id_map = {} - for vtype in vertices: - vertices[vtype]["tmp_id"] = range(len(vertices[vtype])) - id_map[vtype] = vertices[vtype][["vid", "tmp_id"]] - for etype in edges: - source_type = e_attr_types[etype]["FromVertexTypeName"] - target_type = e_attr_types[etype]["ToVertexTypeName"] - if 
e_attr_types[etype]["IsDirected"] or source_type==target_type: - edges[etype] = edges[etype].merge(id_map[source_type], left_on="source", right_on="vid") - edges[etype].drop(columns=["source", "vid"], inplace=True) - edges[etype] = edges[etype].merge(id_map[target_type], left_on="target", right_on="vid") - edges[etype].drop(columns=["target", "vid"], inplace=True) - edgelist[etype] = edges[etype][["tmp_id_x", "tmp_id_y"]] - else: - subdf1 = edges[etype].merge(id_map[source_type], left_on="source", right_on="vid") - subdf1.drop(columns=["source", "vid"], inplace=True) - subdf1 = subdf1.merge(id_map[target_type], left_on="target", right_on="vid") - subdf1.drop(columns=["target", "vid"], inplace=True) - if len(subdf1) < len(edges[etype]): - subdf2 = edges[etype].merge(id_map[source_type], left_on="target", right_on="vid") - subdf2.drop(columns=["target", "vid"], inplace=True) - subdf2 = subdf2.merge(id_map[target_type], left_on="source", right_on="vid") - subdf2.drop(columns=["source", "vid"], inplace=True) - subdf1 = pd.concat((subdf1, subdf2), ignore_index=True) - edges[etype] = subdf1 - edgelist[etype] = edges[etype][["tmp_id_x", "tmp_id_y"]] + id_map = {} + for vtype in vertices: + vertices[vtype]["tmp_id"] = range(len(vertices[vtype])) + id_map[vtype] = vertices[vtype][["vid", "tmp_id"]] + for etype in edges: + source_type = e_attr_types[etype]["FromVertexTypeName"] + target_type = e_attr_types[etype]["ToVertexTypeName"] + if e_attr_types[etype]["IsDirected"] or source_type==target_type: + edges[etype] = edges[etype].merge(id_map[source_type], left_on="source", right_on="vid") + edges[etype].drop(columns=["source", "vid"], inplace=True) + edges[etype] = edges[etype].merge(id_map[target_type], left_on="target", right_on="vid") + edges[etype].drop(columns=["target", "vid"], inplace=True) + edgelist[etype] = edges[etype][["tmp_id_x", "tmp_id_y"]] + else: + subdf1 = edges[etype].merge(id_map[source_type], left_on="source", right_on="vid") + subdf1.drop(columns=["source", "vid"], inplace=True) + subdf1 = subdf1.merge(id_map[target_type], left_on="target", right_on="vid") + subdf1.drop(columns=["target", "vid"], inplace=True) + if len(subdf1) < len(edges[etype]): + subdf2 = edges[etype].merge(id_map[source_type], left_on="target", right_on="vid") + subdf2.drop(columns=["target", "vid"], inplace=True) + subdf2 = subdf2.merge(id_map[target_type], left_on="source", right_on="vid") + subdf2.drop(columns=["source", "vid"], inplace=True) + subdf1 = pd.concat((subdf1, subdf2), ignore_index=True) + edges[etype] = subdf1 + edgelist[etype] = edges[etype][["tmp_id_x", "tmp_id_y"]] else: for etype in edges: edgelist[etype] = edges[etype][["source", "target"]] for etype in edges: edgelist[etype] = torch.tensor(edgelist[etype].to_numpy().T, dtype=torch.long) - if mode == "dgl": - data = dgl.heterograph({ - (e_attr_types[etype]["FromVertexTypeName"], etype, e_attr_types[etype]["ToVertexTypeName"]): (edgelist[etype][0], edgelist[etype][1]) for etype in edgelist}) - if add_self_loop: - data = dgl.add_self_loop(data) - data.extra_data = {} - elif mode == "pyg": - data = pygHeteroData() - for etype in edgelist: - if add_self_loop: - edgelist[etype] = add_self_loops(edgelist[etype])[0] - data[e_attr_types[etype]["FromVertexTypeName"], - etype, - e_attr_types[etype]["ToVertexTypeName"]].edge_index = edgelist[etype] - elif mode == "spektral": - raise NotImplementedError - del edgelist + + data = dgl.heterograph({ + (e_attr_types[etype]["FromVertexTypeName"], etype, e_attr_types[etype]["ToVertexTypeName"]): 
(edgelist[etype][0], edgelist[etype][1]) for etype in edgelist}) + if add_self_loop: + data = dgl.add_self_loop(data) + data.extra_data = {} # Deal with edge attributes if e_in_feats: for etype in edges: @@ -1188,21 +1518,21 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, continue if e_in_feats[etype]: add_attributes(e_in_feats[etype], e_attr_types[etype], edges[etype], - data, is_hetero, mode, "edge_feat", "edge", etype) + data, is_hetero, "edge_feat", "edge", etype) if e_out_labels: for etype in edges: if etype not in e_out_labels: continue if e_out_labels[etype]: add_attributes(e_out_labels[etype], e_attr_types[etype], edges[etype], - data, is_hetero, mode, "edge_label", "edge", etype) + data, is_hetero, "edge_label", "edge", etype) if e_extra_feats: for etype in edges: if etype not in e_extra_feats: continue if e_extra_feats[etype]: add_sep_attr(e_extra_feats[etype], e_attr_types[etype], edges[etype], - data, is_hetero, mode, "edge", etype) + data, is_hetero, "edge", etype) del edges # Deal with vertex attributes next if v_in_feats: @@ -1211,26 +1541,163 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, continue if v_in_feats[vtype]: add_attributes(v_in_feats[vtype], v_attr_types[vtype], vertices[vtype], - data, is_hetero, mode, "x", "vertex", vtype) + data, is_hetero, "x", "vertex", vtype) if v_out_labels: for vtype in vertices: if vtype not in v_out_labels: continue if v_out_labels[vtype]: add_attributes(v_out_labels[vtype], v_attr_types[vtype], vertices[vtype], - data, is_hetero, mode, "y", "vertex", vtype) + data, is_hetero, "y", "vertex", vtype) if v_extra_feats: for vtype in vertices: if vtype not in v_extra_feats: continue if v_extra_feats[vtype]: add_sep_attr(v_extra_feats[vtype], v_attr_types[vtype], vertices[vtype], - data, is_hetero, mode, "vertex", vtype) + data, is_hetero, "vertex", vtype) + del vertices + return data + + @staticmethod + def _parse_df_to_spektral( + raw: Union[Tuple[pd.DataFrame, pd.DataFrame], Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]], + v_in_feats: Union[list, dict] = [], + v_out_labels: Union[list, dict] = [], + v_extra_feats: Union[list, dict] = [], + v_attr_types: dict = {}, + e_in_feats: Union[list, dict] = [], + e_out_labels: Union[list, dict] = [], + e_extra_feats: Union[list, dict] = [], + e_attr_types: dict = {}, + add_self_loop: bool = False, + reindex: bool = True, + is_hetero: bool = False, + scipy = None, + spektral = None + ) -> "spektral.data.graph.Graph": + """Parse dataframes to Spektral graphs. + """ + def attr_to_tensor( + attributes: list, attr_types: dict, df: pd.DataFrame + ) -> "torch.Tensor": + """Turn multiple columns of a dataframe into a tensor. + """ + x = [] + for col in attributes: + dtype = attr_types[col].lower() + if dtype.startswith("str"): + raise TypeError( + "String type not allowed for input and output features." + ) + if dtype.startswith("list"): + dtype2 = dtype.split(":")[1] + x.append(df[col].str.split(expand=True).to_numpy().astype(dtype2)) + elif dtype.startswith("set") or dtype.startswith("map") or dtype.startswith("date"): + raise NotImplementedError( + "{} type not supported for input and output features yet.".format(dtype)) + elif dtype == "bool": + x.append(df[[col]].astype("int8").to_numpy().astype(dtype)) + elif dtype == "uint": + # PyTorch only supports uint8. Need to convert it to int. 
+ x.append(df[[col]].to_numpy().astype("int")) + else: + x.append(df[[col]].to_numpy().astype(dtype)) + try: + return np.squeeze(np.hstack(x), axis=1) #throws an error if axis isn't 1 + except: + return np.hstack(x) + + def add_attributes(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, + graph, is_hetero: bool, feat_name: str, + target: 'Literal["edge", "vertex"]', vetype: str = None) -> None: + """Add multiple attributes as a single feature to edges or vertices. + """ + data = graph + data[feat_name] = attr_to_tensor(attr_names, attr_types, attr_df) + + def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, + graph, is_hetero: bool, + target: 'Literal["edge", "vertex"]', vetype: str = None) -> None: + """Add each attribute as a single feature to edges or vertices. + """ + data = graph + for col in attr_names: + dtype = attr_types[col].lower() + if dtype.startswith("str") or dtype.startswith("map"): + data[col] = attr_df[col].to_list() + elif dtype.startswith("list"): + dtype2 = dtype.split(":")[1] + if dtype2.startswith("str"): + data[col] = attr_df[col].str.split().to_list() + else: + data[col] = attr_df[col].str.split(expand=True).to_numpy().astype(dtype2) + elif dtype.startswith("set") or dtype.startswith("date"): + raise NotImplementedError( + "{} type not supported for extra features yet.".format(dtype)) + elif dtype == "bool": + data[col] = attr_df[col].astype("int8").astype(dtype) + else: + data[col] = attr_df[col].astype(dtype) + + # Reformat as a graph. + # Need to have a pair of tables for edges and vertices. + vertices, edges = raw + if not is_hetero: + # Deal with edgelist first + if reindex: + vertices["tmp_id"] = range(len(vertices)) + id_map = vertices[["vid", "tmp_id"]] + edges = edges.merge(id_map, left_on="source", right_on="vid") + edges.drop(columns=["source", "vid"], inplace=True) + edges = edges.merge(id_map, left_on="target", right_on="vid") + edges.drop(columns=["target", "vid"], inplace=True) + edgelist = edges[["tmp_id_x", "tmp_id_y"]] + else: + edgelist = edges[["source", "target"]] + n_edges = len(edgelist) + n_vertices = len(vertices) + adjacency_data = [1 for i in range(n_edges)] #spektral adjacency format requires weights for each edge to initialize + adjacency = scipy.sparse.coo_matrix((adjacency_data, (edgelist["tmp_id_x"], edgelist["tmp_id_y"])), shape=(n_vertices, n_vertices)) + if add_self_loop: + adjacency = spektral.utils.add_self_loops(adjacency, value=1) + edge_index = np.stack((adjacency.row, adjacency.col), axis=-1) + data = spektral.data.graph.Graph(A=adjacency) + del edgelist + # Deal with edge attributes + if e_in_feats: + add_attributes(e_in_feats, e_attr_types, edges, + data, is_hetero, "edge_feat", "edge") + edge_data = data["edge_feat"] + edge_index, edge_data = spektral.utils.reorder(edge_index, edge_features=edge_data) + n_edges = len(edge_index) + data["e"] = np.array([[i] for i in edge_data]) #if something breaks when you add self-loops it's here + adjacency_data = [1 for i in range(n_edges)] + data["a"] = scipy.sparse.coo_matrix((adjacency_data, (edge_index[:, 0], edge_index[:, 1])), shape=(n_vertices, n_vertices)) + + if e_out_labels: + add_attributes(e_out_labels, e_attr_types, edges, + data, is_hetero, "edge_label", "edge") + if e_extra_feats: + add_sep_attr(e_extra_feats, e_attr_types, edges, + data, is_hetero, "edge") + del edges + # Deal with vertex attributes next + if v_in_feats: + add_attributes(v_in_feats, v_attr_types, vertices, + data, is_hetero, "x", "vertex") + if v_out_labels: + 
add_attributes(v_out_labels, v_attr_types, vertices, + data, is_hetero, "y", "vertex") + if v_extra_feats: + add_sep_attr(v_extra_feats, v_attr_types, vertices, + data, is_hetero, "vertex") del vertices - if callback_fn: - return callback_fn(data) else: - return data + # Heterogeneous graph + # Deal with edgelist first + raise NotImplementedError + return data def _start_request(self, out_tuple: bool, resp_type: str): # If using kafka diff --git a/tests/test_gds_BaseLoader.py b/tests/test_gds_BaseLoader.py index d2987c8d..af75cec6 100644 --- a/tests/test_gds_BaseLoader.py +++ b/tests/test_gds_BaseLoader.py @@ -146,12 +146,10 @@ def test_read_vertex(self): raw = "99|1 0 0 1 |1|0|1\n8|1 0 0 1 |1|1|1\n" read_task_q.put(raw) read_task_q.put(None) - self.loader._read_data( + self.loader._read_vertex_data( exit_event, read_task_q, data_q, - "vertex", - "dataframe", ["x"], ["y"], ["train_mask", "is_seed"], @@ -176,12 +174,10 @@ def test_read_vertex_callback(self): raw = "99|1 0 0 1 |1|0|1\n8|1 0 0 1 |1|1|1\n" read_task_q.put(raw) read_task_q.put(None) - self.loader._read_data( + self.loader._read_vertex_data( exit_event, read_task_q, data_q, - "vertex", - "dataframe", ["x"], ["y"], ["train_mask", "is_seed"], @@ -199,16 +195,10 @@ def test_read_edge(self): raw = "1|2|0.1|2021|1|0\n2|1|1.5|2020|0|1\n" read_task_q.put(raw) read_task_q.put(None) - self.loader._read_data( + self.loader._read_edge_data( exit_event, read_task_q, data_q, - "edge", - "dataframe", - [], - [], - [], - {}, ["x", "time"], ["y"], ["is_train"], @@ -233,16 +223,10 @@ def test_read_edge_callback(self): raw = "1|2|0.1|2021|1|0\n2|1|1.5|2020|0|1\n" read_task_q.put(raw) read_task_q.put(None) - self.loader._read_data( + self.loader._read_edge_data( exit_event, read_task_q, data_q, - "edge", - "dataframe", - [], - [], - [], - {}, ["x", "time"], ["y"], ["is_train"], @@ -264,11 +248,10 @@ def test_read_graph_out_df(self): ) read_task_q.put(raw) read_task_q.put(None) - self.loader._read_data( + self.loader._read_graph_data( exit_event, read_task_q, data_q, - "graph", "dataframe", ["x"], ["y"], @@ -309,11 +292,10 @@ def test_read_graph_out_df_callback(self): ) read_task_q.put(raw) read_task_q.put(None) - self.loader._read_data( + self.loader._read_graph_data( exit_event, read_task_q, data_q, - "graph", "dataframe", ["x"], ["y"], @@ -341,11 +323,10 @@ def test_read_graph_out_pyg(self): ) read_task_q.put(raw) read_task_q.put(None) - self.loader._read_data( + self.loader._read_graph_data( exit_event, read_task_q, data_q, - "graph", "pyg", ["x"], ["y"], @@ -391,11 +372,10 @@ def test_read_graph_out_dgl(self): ) read_task_q.put(raw) read_task_q.put(None) - self.loader._read_data( + self.loader._read_graph_data( exit_event, read_task_q, data_q, - "graph", "dgl", ["x"], ["y"], @@ -441,11 +421,10 @@ def test_read_graph_parse_error(self): ) read_task_q.put(raw) read_task_q.put(None) - self.loader._read_data( + self.loader._read_graph_data( exit_event, read_task_q, data_q, - "graph", "dgl", ["x"], ["y"], @@ -473,11 +452,10 @@ def test_read_graph_no_attr(self): raw = ("99|1\n8|0\n", "99|8\n8|99\n") read_task_q.put(raw) read_task_q.put(None) - self.loader._read_data( + self.loader._read_graph_data( exit_event, read_task_q, data_q, - "graph", "pyg", [], [], @@ -512,11 +490,10 @@ def test_read_graph_no_edge(self): ) read_task_q.put(raw) read_task_q.put(None) - self.loader._read_data( + self.loader._read_graph_data( exit_event, read_task_q, data_q, - "graph", "pyg", ["x"], ["y"], @@ -558,11 +535,10 @@ def test_read_hetero_graph_out_pyg(self): ) 
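# raw here is one (vertex_batch, edge_batch) pair of "|"-delimited strings; in the
# heterogeneous case each row is prefixed with its type, e.g. a vertex row like
# "People|99|..." and an edge row like "Colleague|99|8|..." (illustrative forms,
# matching the vtype,vid,... / etype,source,target,... layout described in the parsers).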
read_task_q.put(raw) read_task_q.put(None) - self.loader._read_data( + self.loader._read_graph_data( exit_event, read_task_q, data_q, - "graph", "pyg", {"People": ["x"], "Company": ["x"]}, {"People": ["y"]}, @@ -639,11 +615,10 @@ def test_read_hetero_graph_no_attr(self): ) read_task_q.put(raw) read_task_q.put(None) - self.loader._read_data( + self.loader._read_graph_data( exit_event, read_task_q, data_q, - "graph", "pyg", {"People": [], "Company": []}, {"People": [], "Company": []}, @@ -705,11 +680,10 @@ def test_read_hetero_graph_no_edge(self): ) read_task_q.put(raw) read_task_q.put(None) - self.loader._read_data( + self.loader._read_graph_data( exit_event, read_task_q, data_q, - "graph", "pyg", {"People": ["x"], "Company": ["x"]}, {"People": ["y"]}, @@ -776,11 +750,10 @@ def test_read_hetero_graph_out_dgl(self): ) read_task_q.put(raw) read_task_q.put(None) - self.loader._read_data( + self.loader._read_graph_data( exit_event, read_task_q, data_q, - "graph", "dgl", {"People": ["x"], "Company": ["x"]}, {"People": ["y"]}, @@ -859,11 +832,10 @@ def test_read_bool_label(self): ) read_task_q.put(raw) read_task_q.put(None) - self.loader._read_data( + self.loader._read_graph_data( exit_event, read_task_q, data_q, - "graph", "pyg", ["x"], ["y"], From 324fea0ea9d070bf48361e88986260678e7f1c73 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Mon, 9 Oct 2023 12:30:39 -0700 Subject: [PATCH 02/36] refactor(BaseLoader): consolidate vertex and edge parse --- pyTigerGraph/gds/dataloaders.py | 141 +++++++++++++++----------------- 1 file changed, 64 insertions(+), 77 deletions(-) diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index 63b15120..cdf6ff62 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -728,7 +728,22 @@ def _read_graph_data( primary_id = {}, is_hetero = is_hetero, ) - if out_format == "pyg": + if out_format == "dataframe" or out_format == "df": + vertices, edges = data + if not is_hetero: + for column in vertices.columns: + vertices[column] = pd.to_numeric(vertices[column], errors="ignore") + for column in edges.columns: + edges[column] = pd.to_numeric(edges[column], errors="ignore") + else: + for key in vertices: + for column in vertices[key].columns: + vertices[key][column] = pd.to_numeric(vertices[key][column], errors="ignore") + for key in edges: + for column in edges[key].columns: + edges[key][column] = pd.to_numeric(edges[key][column], errors="ignore") + data = (vertices, edges) + elif out_format == "pyg": data = BaseLoader._parse_df_to_pyg( raw = data, v_in_feats = v_in_feats, @@ -779,8 +794,6 @@ def _read_graph_data( scipy = scipy, spektral = spektral ) - elif out_format == "dataframe" or out_format == "df": - pass else: raise NotImplementedError if callback_fn: @@ -822,6 +835,13 @@ def _read_vertex_data( delimiter = delimiter, is_hetero = is_hetero ) + if not is_hetero: + for column in data.columns: + data[column] = pd.to_numeric(data[column], errors="ignore") + else: + for key in data: + for column in data[key].columns: + data[key][column] = pd.to_numeric(data[key][column], errors="ignore") if callback_fn: data = callback_fn(data) out_q.put(data) @@ -861,6 +881,13 @@ def _read_edge_data( delimiter = delimiter, is_hetero = is_hetero ) + if not is_hetero: + for column in data.columns: + data[column] = pd.to_numeric(data[column], errors="ignore") + else: + for key in data: + for column in data[key].columns: + data[key][column] = pd.to_numeric(data[key][column], errors="ignore") if callback_fn: data = callback_fn(data) 
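# Hand the parsed batch (after any callback transform) to the output queue for the
# consumer on the other end (e.g. data_q.get() in the tests).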
out_q.put(data) @@ -888,10 +915,8 @@ def _parse_vertex_data( # String of vertices in format vid,v_in_feats,v_out_labels,v_extra_feats v_attributes = ["vid"] + v_in_feats + v_out_labels + v_extra_feats v_file = (line.strip().split(delimiter) for line in raw.splitlines()) - data = pd.DataFrame(v_file, columns=v_attributes) - for column in data.columns: - data[column] = pd.to_numeric(data[column], errors="ignore") - for v_attr in v_attributes: + data = pd.DataFrame(v_file, columns=v_attributes, dtype="object") + for v_attr in v_extra_feats: if v_attr_types.get(v_attr, "") == "MAP": # I am sorry that this is this ugly... data[v_attr] = data[v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) @@ -901,18 +926,17 @@ def _parse_vertex_data( v_file_dict = defaultdict(list) for line in v_file: v_file_dict[line[0]].append(line[1:]) - vertices = {} + data = {} for vtype in v_file_dict: v_attributes = ["vid"] + \ v_in_feats.get(vtype, []) + \ v_out_labels.get(vtype, []) + \ v_extra_feats.get(vtype, []) - vertices[vtype] = pd.DataFrame(v_file_dict[vtype], columns=v_attributes) + data[vtype] = pd.DataFrame(v_file_dict[vtype], columns=v_attributes, dtype="object") for v_attr in v_extra_feats.get(vtype, []): if v_attr_types[vtype][v_attr] == "MAP": # I am sorry that this is this ugly... - vertices[vtype][v_attr] = vertices[vtype][v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) - data = vertices + data[vtype][v_attr] = data[vtype][v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) return data @staticmethod @@ -934,10 +958,8 @@ def _parse_edge_data( #file = "\n".join(x for x in raw.split("\n") if x.strip()) #data = pd.read_table(io.StringIO(file), header=None, names=e_attributes, sep=delimiter) e_file = (line.strip().split(delimiter) for line in raw.splitlines()) - data = pd.DataFrame(e_file, columns=e_attributes) - for column in data.columns: - data[column] = pd.to_numeric(data[column], errors="ignore") - for e_attr in e_attributes: + data = pd.DataFrame(e_file, columns=e_attributes, dtype="object") + for e_attr in e_extra_feats: if e_attr_types.get(e_attr, "") == "MAP": # I am sorry that this is this ugly... data[e_attr] = data[e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) @@ -947,20 +969,17 @@ def _parse_edge_data( e_file_dict = defaultdict(list) for line in e_file: e_file_dict[line[0]].append(line[1:]) - edges = {} + data = {} for etype in e_file_dict: e_attributes = ["source", "target"] + \ e_in_feats.get(etype, []) + \ e_out_labels.get(etype, []) + \ e_extra_feats.get(etype, []) - edges[etype] = pd.DataFrame(e_file_dict[etype], columns=e_attributes) + data[etype] = pd.DataFrame(e_file_dict[etype], columns=e_attributes, dtype="object") for e_attr in e_extra_feats.get(etype, []): if e_attr_types[etype][e_attr] == "MAP": # I am sorry that this is this ugly... 
- edges[etype][e_attr] = edges[etype][e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) - del e_file_dict, e_file - data = edges - + data[etype][e_attr] = data[etype][e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) return data @staticmethod @@ -983,65 +1002,33 @@ def _parse_graph_data_to_df( # Read in vertex and edge CSVs as dataframes # A pair of in-memory CSVs (vertex, edge) v_file, e_file = raw - if not is_hetero: - v_attributes = ["vid"] + v_in_feats + v_out_labels + v_extra_feats - e_attributes = ["source", "target"] + e_in_feats + e_out_labels + e_extra_feats - v_file = (line.split(delimiter) for line in v_file.splitlines()) - vertices = pd.DataFrame(v_file, columns=v_attributes) - for column in vertices.columns: - vertices[column] = pd.to_numeric(vertices[column], errors="ignore") - for v_attr in v_extra_feats: - if v_attr_types[v_attr] == "MAP": - # I am sorry that this is this ugly... - vertices[v_attr] = vertices[v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) - if primary_id: - id_map = pd.DataFrame({"vid": primary_id.keys(), "primary_id": primary_id.values()}) - vertices = vertices.merge(id_map.astype({"vid": vertices["vid"].dtype}), on="vid") + vertices = BaseLoader._parse_vertex_data( + raw = v_file, + v_in_feats = v_in_feats, + v_out_labels = v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + delimiter = delimiter, + is_hetero = is_hetero) + if primary_id: + id_map = pd.DataFrame({"vid": primary_id.keys(), "primary_id": primary_id.values()}, + dtype="object") + if not is_hetero: + vertices = vertices.merge(id_map, on="vid") v_extra_feats.append("primary_id") - e_file = (line.split(delimiter) for line in e_file.splitlines()) - edges = pd.DataFrame(e_file, columns=e_attributes) - for column in edges.columns: - edges[column] = pd.to_numeric(edges[column], errors="ignore") - for e_attr in e_attributes: - if e_attr_types.get(e_attr, "") == "MAP": - # I am sorry that this is this ugly... - edges[e_attr] = edges[e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) - else: - v_file = (line.split(delimiter) for line in v_file.splitlines()) - v_file_dict = defaultdict(list) - for line in v_file: - v_file_dict[line[0]].append(line[1:]) - vertices = {} - for vtype in v_file_dict: - v_attributes = ["vid"] + \ - v_in_feats.get(vtype, []) + \ - v_out_labels.get(vtype, []) + \ - v_extra_feats.get(vtype, []) - vertices[vtype] = pd.DataFrame(v_file_dict[vtype], columns=v_attributes, dtype="object") - for v_attr in v_extra_feats.get(vtype, []): - if v_attr_types[vtype][v_attr] == "MAP": - # I am sorry that this is this ugly... 
- vertices[vtype][v_attr] = vertices[vtype][v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) - if primary_id: - id_map = pd.DataFrame({"vid": primary_id.keys(), "primary_id": primary_id.values()}, - dtype="object") + else: for vtype in vertices: vertices[vtype] = vertices[vtype].merge(id_map, on="vid") - e_file = (line.split(delimiter) for line in e_file.splitlines()) - e_file_dict = defaultdict(list) - for line in e_file: - e_file_dict[line[0]].append(line[1:]) - edges = {} - for etype in e_file_dict: - e_attributes = ["source", "target"] + \ - e_in_feats.get(etype, []) + \ - e_out_labels.get(etype, []) + \ - e_extra_feats.get(etype, []) - edges[etype] = pd.DataFrame(e_file_dict[etype], columns=e_attributes, dtype="object") - for e_attr in e_extra_feats.get(etype, []): - if e_attr_types[etype][e_attr] == "MAP": - # I am sorry that this is this ugly... - edges[etype][e_attr] = edges[etype][e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) + v_extra_feats[vtype].append("primary_id") + edges = BaseLoader._parse_edge_data( + raw = e_file, + e_in_feats = e_in_feats, + e_out_labels = e_out_labels, + e_extra_feats = e_extra_feats, + e_attr_types = e_attr_types, + delimiter = delimiter, + is_hetero = is_hetero + ) return (vertices, edges) @staticmethod From 441a0b3197d3b2679ad57dd95a41394efabdaba1 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Mon, 9 Oct 2023 14:10:10 -0700 Subject: [PATCH 03/36] refactor(BaseLoader): consolidate attr_to_tensor and reindex --- pyTigerGraph/gds/dataloaders.py | 270 ++++++++++---------------------- 1 file changed, 84 insertions(+), 186 deletions(-) diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index cdf6ff62..bd3212b8 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -1031,6 +1031,71 @@ def _parse_graph_data_to_df( ) return (vertices, edges) + @staticmethod + def _attributes_to_np( + attributes: list, attr_types: dict, df: pd.DataFrame + ) -> np.ndarray: + """Turn multiple columns of a dataframe into a numpy array. + """ + x = [] + for col in attributes: + dtype = attr_types[col].lower() + if dtype.startswith("str"): + raise TypeError( + "String type not allowed for input and output features." + ) + if dtype.startswith("list"): + dtype2 = dtype.split(":")[1] + x.append(df[col].str.split(expand=True).to_numpy().astype(dtype2)) + elif dtype.startswith("set") or dtype.startswith("map") or dtype.startswith("date"): + raise NotImplementedError( + "{} type not supported for input and output features yet.".format(dtype)) + elif dtype == "bool": + x.append(df[[col]].astype("int8").to_numpy().astype(dtype)) + elif dtype == "uint": + # PyTorch only supports uint8. Need to convert it to int. 
+ x.append(df[[col]].to_numpy().astype("int")) + else: + x.append(df[[col]].to_numpy().astype(dtype)) + return np.hstack(x) + + @staticmethod + def _get_edgelist( + vertices: Union[pd.DataFrame, Dict[str, pd.DataFrame]], + edges: Union[pd.DataFrame, Dict[str, pd.DataFrame]], + is_hetero: bool, + e_attr_types: dict = {} + ): + if not is_hetero: + vertices["tmp_id"] = range(len(vertices)) + id_map = vertices[["vid", "tmp_id"]] + edgelist = edges[["source", "target"]].merge(id_map, left_on="source", right_on="vid") + edgelist = edgelist.merge(id_map, left_on="target", right_on="vid") + edgelist = edgelist[["tmp_id_x", "tmp_id_y"]] + else: + edgelist = {} + id_map = {} + for vtype in vertices: + vertices[vtype]["tmp_id"] = range(len(vertices[vtype])) + id_map[vtype] = vertices[vtype][["vid", "tmp_id"]] + for etype in edges: + source_type = e_attr_types[etype]["FromVertexTypeName"] + target_type = e_attr_types[etype]["ToVertexTypeName"] + if e_attr_types[etype]["IsDirected"] or source_type==target_type: + edgelist[etype] = edges[etype][["source", "target"]].merge(id_map[source_type], left_on="source", right_on="vid") + edgelist[etype] = edgelist[etype].merge(id_map[target_type], left_on="target", right_on="vid") + edgelist[etype] = edgelist[etype][["tmp_id_x", "tmp_id_y"]] + else: + subdf1 = edges[etype].merge(id_map[source_type], left_on="source", right_on="vid") + subdf1 = subdf1.merge(id_map[target_type], left_on="target", right_on="vid") + if len(subdf1) < len(edges[etype]): + subdf2 = edges[etype].merge(id_map[source_type], left_on="target", right_on="vid") + subdf2 = subdf2.merge(id_map[target_type], left_on="source", right_on="vid") + subdf1 = pd.concat((subdf1, subdf2), ignore_index=True) + edges[etype] = subdf1 + edgelist[etype] = edges[etype][["tmp_id_x", "tmp_id_y"]] + return edgelist + @staticmethod def _parse_df_to_pyg( raw: Union[Tuple[pd.DataFrame, pd.DataFrame], Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]], @@ -1050,33 +1115,6 @@ def _parse_df_to_pyg( ) -> Union["pyg.data.Data", "pyg.data.HeteroData"]: """Parse dataframes to PyG graphs. """ - def attr_to_tensor( - attributes: list, attr_types: dict, df: pd.DataFrame - ) -> "torch.Tensor": - """Turn multiple columns of a dataframe into a tensor. - """ - x = [] - for col in attributes: - dtype = attr_types[col].lower() - if dtype.startswith("str"): - raise TypeError( - "String type not allowed for input and output features." - ) - if dtype.startswith("list"): - dtype2 = dtype.split(":")[1] - x.append(df[col].str.split(expand=True).to_numpy().astype(dtype2)) - elif dtype.startswith("set") or dtype.startswith("map") or dtype.startswith("date"): - raise NotImplementedError( - "{} type not supported for input and output features yet.".format(dtype)) - elif dtype == "bool": - x.append(df[[col]].astype("int8").to_numpy().astype(dtype)) - elif dtype == "uint": - # PyTorch only supports uint8. Need to convert it to int. 
- x.append(df[[col]].to_numpy().astype("int")) - else: - x.append(df[[col]].to_numpy().astype(dtype)) - return torch.tensor(np.hstack(x)).squeeze(dim=1) - def add_attributes(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, graph, is_hetero: bool, feat_name: str, target: 'Literal["edge", "vertex"]', vetype: str = None) -> None: @@ -1094,7 +1132,8 @@ def add_attributes(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, data = graph[vetype] else: data = graph - data[feat_name] = attr_to_tensor(attr_names, attr_types, attr_df) + array = BaseLoader._attributes_to_np(attr_names, attr_types, attr_df) + data[feat_name] = torch.tensor(array).squeeze(dim=1) def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, graph, is_hetero: bool, @@ -1150,19 +1189,9 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, # Reformat as a graph. # Need to have a pair of tables for edges and vertices. vertices, edges = raw + edgelist = BaseLoader._get_edgelist(vertices, edges, is_hetero, e_attr_types) if not is_hetero: # Deal with edgelist first - if reindex: - vertices["tmp_id"] = range(len(vertices)) - id_map = vertices[["vid", "tmp_id"]] - edges = edges.merge(id_map, left_on="source", right_on="vid") - edges.drop(columns=["source", "vid"], inplace=True) - edges = edges.merge(id_map, left_on="target", right_on="vid") - edges.drop(columns=["target", "vid"], inplace=True) - edgelist = edges[["tmp_id_x", "tmp_id_y"]] - else: - edgelist = edges[["source", "target"]] - edgelist = torch.tensor(edgelist.to_numpy().T, dtype=torch.long) if add_self_loop: edgelist = pyg.utils.add_self_loops(edgelist)[0] @@ -1178,6 +1207,7 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, if e_extra_feats: add_sep_attr(e_extra_feats, e_attr_types, edges, data, is_hetero, "edge") + del edges # Deal with vertex attributes next if v_in_feats: add_attributes(v_in_feats, v_attr_types, vertices, @@ -1188,40 +1218,10 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, if v_extra_feats: add_sep_attr(v_extra_feats, v_attr_types, vertices, data, is_hetero, "vertex") + del vertices else: # Heterogeneous graph # Deal with edgelist first - edgelist = {} - if reindex: - id_map = {} - for vtype in vertices: - vertices[vtype]["tmp_id"] = range(len(vertices[vtype])) - id_map[vtype] = vertices[vtype][["vid", "tmp_id"]] - for etype in edges: - source_type = e_attr_types[etype]["FromVertexTypeName"] - target_type = e_attr_types[etype]["ToVertexTypeName"] - if e_attr_types[etype]["IsDirected"] or source_type==target_type: - edges[etype] = edges[etype].merge(id_map[source_type], left_on="source", right_on="vid") - edges[etype].drop(columns=["source", "vid"], inplace=True) - edges[etype] = edges[etype].merge(id_map[target_type], left_on="target", right_on="vid") - edges[etype].drop(columns=["target", "vid"], inplace=True) - edgelist[etype] = edges[etype][["tmp_id_x", "tmp_id_y"]] - else: - subdf1 = edges[etype].merge(id_map[source_type], left_on="source", right_on="vid") - subdf1.drop(columns=["source", "vid"], inplace=True) - subdf1 = subdf1.merge(id_map[target_type], left_on="target", right_on="vid") - subdf1.drop(columns=["target", "vid"], inplace=True) - if len(subdf1) < len(edges[etype]): - subdf2 = edges[etype].merge(id_map[source_type], left_on="target", right_on="vid") - subdf2.drop(columns=["target", "vid"], inplace=True) - subdf2 = subdf2.merge(id_map[target_type], left_on="source", right_on="vid") - subdf2.drop(columns=["source", 
"vid"], inplace=True) - subdf1 = pd.concat((subdf1, subdf2), ignore_index=True) - edges[etype] = subdf1 - edgelist[etype] = edges[etype][["tmp_id_x", "tmp_id_y"]] - else: - for etype in edges: - edgelist[etype] = edges[etype][["source", "target"]] for etype in edges: edgelist[etype] = torch.tensor(edgelist[etype].to_numpy().T, dtype=torch.long) if add_self_loop: @@ -1252,7 +1252,8 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, continue if e_extra_feats[etype]: add_sep_attr(e_extra_feats[etype], e_attr_types[etype], edges[etype], - data, is_hetero, "edge", etype) + data, is_hetero, "edge", etype) + del edges # Deal with vertex attributes next if v_in_feats: for vtype in vertices: @@ -1275,6 +1276,7 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, if v_extra_feats[vtype]: add_sep_attr(v_extra_feats[vtype], v_attr_types[vtype], vertices[vtype], data, is_hetero, "vertex", vtype) + del vertices return data @staticmethod @@ -1296,33 +1298,6 @@ def _parse_df_to_dgl( ) -> Union["dgl.graph", "dgl.heterograph"]: """Parse dataframes to PyG graphs. """ - def attr_to_tensor( - attributes: list, attr_types: dict, df: pd.DataFrame - ) -> "torch.Tensor": - """Turn multiple columns of a dataframe into a tensor. - """ - x = [] - for col in attributes: - dtype = attr_types[col].lower() - if dtype.startswith("str"): - raise TypeError( - "String type not allowed for input and output features." - ) - if dtype.startswith("list"): - dtype2 = dtype.split(":")[1] - x.append(df[col].str.split(expand=True).to_numpy().astype(dtype2)) - elif dtype.startswith("set") or dtype.startswith("map") or dtype.startswith("date"): - raise NotImplementedError( - "{} type not supported for input and output features yet.".format(dtype)) - elif dtype == "bool": - x.append(df[[col]].astype("int8").to_numpy().astype(dtype)) - elif dtype == "uint": - # PyTorch only supports uint8. Need to convert it to int. - x.append(df[[col]].to_numpy().astype("int")) - else: - x.append(df[[col]].to_numpy().astype(dtype)) - return torch.tensor(np.hstack(x)).squeeze(dim=1) - def add_attributes(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, graph, is_hetero: bool, feat_name: str, target: 'Literal["edge", "vertex"]', vetype: str = None) -> None: @@ -1341,8 +1316,8 @@ def add_attributes(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, data = graph.edata elif target == "vertex": data = graph.ndata - - data[feat_name] = attr_to_tensor(attr_names, attr_types, attr_df) + array = BaseLoader._attributes_to_np(attr_names, attr_types, attr_df) + data[feat_name] = torch.tensor(array).squeeze(dim=1) def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, graph, is_hetero: bool, @@ -1416,19 +1391,9 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, # Reformat as a graph. # Need to have a pair of tables for edges and vertices. 
vertices, edges = raw + edgelist = BaseLoader._get_edgelist(vertices, edges, is_hetero, e_attr_types) if not is_hetero: # Deal with edgelist first - if reindex: - vertices["tmp_id"] = range(len(vertices)) - id_map = vertices[["vid", "tmp_id"]] - edges = edges.merge(id_map, left_on="source", right_on="vid") - edges.drop(columns=["source", "vid"], inplace=True) - edges = edges.merge(id_map, left_on="target", right_on="vid") - edges.drop(columns=["target", "vid"], inplace=True) - edgelist = edges[["tmp_id_x", "tmp_id_y"]] - else: - edgelist = edges[["source", "target"]] - edgelist = torch.tensor(edgelist.to_numpy().T, dtype=torch.long) data = dgl.graph(data=(edgelist[0], edgelist[1])) if add_self_loop: @@ -1459,40 +1424,8 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, else: # Heterogeneous graph # Deal with edgelist first - edgelist = {} - if reindex: - id_map = {} - for vtype in vertices: - vertices[vtype]["tmp_id"] = range(len(vertices[vtype])) - id_map[vtype] = vertices[vtype][["vid", "tmp_id"]] - for etype in edges: - source_type = e_attr_types[etype]["FromVertexTypeName"] - target_type = e_attr_types[etype]["ToVertexTypeName"] - if e_attr_types[etype]["IsDirected"] or source_type==target_type: - edges[etype] = edges[etype].merge(id_map[source_type], left_on="source", right_on="vid") - edges[etype].drop(columns=["source", "vid"], inplace=True) - edges[etype] = edges[etype].merge(id_map[target_type], left_on="target", right_on="vid") - edges[etype].drop(columns=["target", "vid"], inplace=True) - edgelist[etype] = edges[etype][["tmp_id_x", "tmp_id_y"]] - else: - subdf1 = edges[etype].merge(id_map[source_type], left_on="source", right_on="vid") - subdf1.drop(columns=["source", "vid"], inplace=True) - subdf1 = subdf1.merge(id_map[target_type], left_on="target", right_on="vid") - subdf1.drop(columns=["target", "vid"], inplace=True) - if len(subdf1) < len(edges[etype]): - subdf2 = edges[etype].merge(id_map[source_type], left_on="target", right_on="vid") - subdf2.drop(columns=["target", "vid"], inplace=True) - subdf2 = subdf2.merge(id_map[target_type], left_on="source", right_on="vid") - subdf2.drop(columns=["source", "vid"], inplace=True) - subdf1 = pd.concat((subdf1, subdf2), ignore_index=True) - edges[etype] = subdf1 - edgelist[etype] = edges[etype][["tmp_id_x", "tmp_id_y"]] - else: - for etype in edges: - edgelist[etype] = edges[etype][["source", "target"]] for etype in edges: edgelist[etype] = torch.tensor(edgelist[etype].to_numpy().T, dtype=torch.long) - data = dgl.heterograph({ (e_attr_types[etype]["FromVertexTypeName"], etype, e_attr_types[etype]["ToVertexTypeName"]): (edgelist[etype][0], edgelist[etype][1]) for etype in edgelist}) if add_self_loop: @@ -1564,44 +1497,19 @@ def _parse_df_to_spektral( spektral = None ) -> "spektral.data.graph.Graph": """Parse dataframes to Spektral graphs. - """ - def attr_to_tensor( - attributes: list, attr_types: dict, df: pd.DataFrame - ) -> "torch.Tensor": - """Turn multiple columns of a dataframe into a tensor. - """ - x = [] - for col in attributes: - dtype = attr_types[col].lower() - if dtype.startswith("str"): - raise TypeError( - "String type not allowed for input and output features." 
- ) - if dtype.startswith("list"): - dtype2 = dtype.split(":")[1] - x.append(df[col].str.split(expand=True).to_numpy().astype(dtype2)) - elif dtype.startswith("set") or dtype.startswith("map") or dtype.startswith("date"): - raise NotImplementedError( - "{} type not supported for input and output features yet.".format(dtype)) - elif dtype == "bool": - x.append(df[[col]].astype("int8").to_numpy().astype(dtype)) - elif dtype == "uint": - # PyTorch only supports uint8. Need to convert it to int. - x.append(df[[col]].to_numpy().astype("int")) - else: - x.append(df[[col]].to_numpy().astype(dtype)) - try: - return np.squeeze(np.hstack(x), axis=1) #throws an error if axis isn't 1 - except: - return np.hstack(x) - + """ def add_attributes(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, graph, is_hetero: bool, feat_name: str, target: 'Literal["edge", "vertex"]', vetype: str = None) -> None: """Add multiple attributes as a single feature to edges or vertices. """ data = graph - data[feat_name] = attr_to_tensor(attr_names, attr_types, attr_df) + array = BaseLoader._attributes_to_np(attr_names, attr_types, attr_df) + try: + array = np.squeeze(array, axis=1) + except: + pass + data[feat_name] = array def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, graph, is_hetero: bool, @@ -1630,18 +1538,9 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, # Reformat as a graph. # Need to have a pair of tables for edges and vertices. vertices, edges = raw + edgelist = BaseLoader._get_edgelist(vertices, edges, is_hetero, e_attr_types) if not is_hetero: # Deal with edgelist first - if reindex: - vertices["tmp_id"] = range(len(vertices)) - id_map = vertices[["vid", "tmp_id"]] - edges = edges.merge(id_map, left_on="source", right_on="vid") - edges.drop(columns=["source", "vid"], inplace=True) - edges = edges.merge(id_map, left_on="target", right_on="vid") - edges.drop(columns=["target", "vid"], inplace=True) - edgelist = edges[["tmp_id_x", "tmp_id_y"]] - else: - edgelist = edges[["source", "target"]] n_edges = len(edgelist) n_vertices = len(vertices) adjacency_data = [1 for i in range(n_edges)] #spektral adjacency format requires weights for each edge to initialize @@ -1661,7 +1560,6 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, data["e"] = np.array([[i] for i in edge_data]) #if something breaks when you add self-loops it's here adjacency_data = [1 for i in range(n_edges)] data["a"] = scipy.sparse.coo_matrix((adjacency_data, (edge_index[:, 0], edge_index[:, 1])), shape=(n_vertices, n_vertices)) - if e_out_labels: add_attributes(e_out_labels, e_attr_types, edges, data, is_hetero, "edge_label", "edge") From 65cca461010335a4332d1f3eeefd7011cde4a072 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Tue, 10 Oct 2023 12:46:44 -0700 Subject: [PATCH 04/36] feat(BaseLoader): handle one data point per msg --- pyTigerGraph/gds/dataloaders.py | 116 ++-- tests/test_gds_BaseLoader.py | 946 ++++++++++++++++++-------------- 2 files changed, 600 insertions(+), 462 deletions(-) diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index bd3212b8..51f1dd9d 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -18,6 +18,7 @@ from threading import Event, Thread from time import sleep import pickle +import random from typing import TYPE_CHECKING, Any, Iterator, NoReturn, Tuple, Union, Callable, List, Dict #import re @@ -562,7 +563,6 @@ def _request_graph_rest( # Put raw data into reading queue 
for i in resp: read_task_q.put((i["vertex_batch"], i["edge_batch"])) - read_task_q.put(None) @staticmethod def _request_unimode_rest( @@ -579,7 +579,6 @@ def _request_unimode_rest( # Put raw data into reading queue for i in resp: read_task_q.put(i["data_batch"]) - read_task_q.put(None) @staticmethod def _download_graph_kafka( @@ -622,7 +621,6 @@ def _download_graph_kafka( raise ValueError( "Unrecognized key {} for messages in kafka".format(key) ) - read_task_q.put(None) def _download_unimode_kafka( exit_event: Event, @@ -642,13 +640,14 @@ def _download_unimode_kafka( for message in msgs: read_task_q.put(message.value.decode("utf-8")) delivered_batch += 1 - read_task_q.put(None) @staticmethod def _read_graph_data( exit_event: Event, in_q: Queue, out_q: Queue, + batch_size: int, + shuffle: bool = False, out_format: str = "dataframe", v_in_feats: Union[list, dict] = [], v_out_labels: Union[list, dict] = [], @@ -660,7 +659,6 @@ def _read_graph_data( e_attr_types: dict = {}, add_self_loop: bool = False, delimiter: str = "|", - reindex: bool = True, is_hetero: bool = False, callback_fn: Callable = None ) -> NoReturn: @@ -707,15 +705,29 @@ def _read_graph_data( "Spektral is not installed. Please install it to use spektral output." ) # Get raw data from queue and parse + vertex_buffer = [] + edge_buffer = [] + buffer_size = 0 while not exit_event.is_set(): - raw = in_q.get() - if raw is None: + try: + raw = in_q.get(timeout=1) + except Empty: + continue + # if shuffle the data, 50% chance to save this data point for later + if shuffle and (random.random() < 0.5): in_q.task_done() - out_q.put(None) - break + in_q.put(raw) + continue + # Store raw into buffer until there are enough data points for a batch + vertex_buffer.extend(raw[0].splitlines()) + edge_buffer.extend(raw[1].splitlines()) + buffer_size += 1 + in_q.task_done() + if buffer_size < batch_size: + continue try: data = BaseLoader._parse_graph_data_to_df( - raw = raw, + raw = (vertex_buffer, edge_buffer), v_in_feats = v_in_feats, v_out_labels = v_out_labels, v_extra_feats = v_extra_feats, @@ -755,7 +767,6 @@ def _read_graph_data( e_extra_feats = e_extra_feats, e_attr_types = e_attr_types, add_self_loop = add_self_loop, - reindex = reindex, is_hetero = is_hetero, torch = torch, pyg = pyg @@ -772,7 +783,6 @@ def _read_graph_data( e_extra_feats = e_extra_feats, e_attr_types = e_attr_types, add_self_loop = add_self_loop, - reindex = reindex, is_hetero = is_hetero, torch = torch, dgl= dgl @@ -789,7 +799,6 @@ def _read_graph_data( e_extra_feats = e_extra_feats, e_attr_types = e_attr_types, add_self_loop = add_self_loop, - reindex = reindex, is_hetero = is_hetero, scipy = scipy, spektral = spektral @@ -802,15 +811,20 @@ def _read_graph_data( except Exception as err: warnings.warn("Error parsing a graph batch. 
Set logging level to ERROR for details.") logger.error(err, exc_info=True) - logger.error("Error parsing data: {}".format(raw)) + logger.error("Error parsing data: {}".format((vertex_buffer, edge_buffer))) logger.error("Parameters:\n out_format={}\n v_in_feats={}\n v_out_labels={}\n v_extra_feats={}\n v_attr_types={}\n e_in_feats={}\n e_out_labels={}\n e_extra_feats={}\n e_attr_types={}\n delimiter={}\n".format( out_format, v_in_feats, v_out_labels, v_extra_feats, v_attr_types, e_in_feats, e_out_labels, e_extra_feats, e_attr_types, delimiter)) + vertex_buffer.clear() + edge_buffer.clear() + buffer_size = 0 @staticmethod def _read_vertex_data( exit_event: Event, in_q: Queue, out_q: Queue, + batch_size: int, + shuffle: bool = False, v_in_feats: Union[list, dict] = [], v_out_labels: Union[list, dict] = [], v_extra_feats: Union[list, dict] = [], @@ -819,15 +833,25 @@ def _read_vertex_data( is_hetero: bool = False, callback_fn: Callable = None ) -> NoReturn: + buffer = [] while not exit_event.is_set(): - raw = in_q.get() - if raw is None: + try: + raw = in_q.get(timeout=1) + except Empty: + continue + # if shuffle the data, 50% chance to save this data point for later + if shuffle and (random.random() < 0.5): in_q.task_done() - out_q.put(None) - break + in_q.put(raw) + continue + # Store raw into buffer until there are enough data points for a batch + buffer.append(raw) + in_q.task_done() + if len(buffer) < batch_size: + continue try: data = BaseLoader._parse_vertex_data( - raw = raw, + raw = buffer, v_in_feats = v_in_feats, v_out_labels = v_out_labels, v_extra_feats = v_extra_feats, @@ -848,15 +872,18 @@ def _read_vertex_data( except Exception as err: warnings.warn("Error parsing a vertex batch. Set logging level to ERROR for details.") logger.error(err, exc_info=True) - logger.error("Error parsing data: {}".format(raw)) + logger.error("Error parsing data: {}".format(buffer)) logger.error("Parameters:\n v_in_feats={}\n v_out_labels={}\n v_extra_feats={}\n v_attr_types={}\n delimiter={}\n".format( v_in_feats, v_out_labels, v_extra_feats, v_attr_types, delimiter)) + buffer.clear() @staticmethod def _read_edge_data( exit_event: Event, in_q: Queue, out_q: Queue, + batch_size: int, + shuffle: bool = False, e_in_feats: Union[list, dict] = [], e_out_labels: Union[list, dict] = [], e_extra_feats: Union[list, dict] = [], @@ -865,15 +892,25 @@ def _read_edge_data( is_hetero: bool = False, callback_fn: Callable = None ) -> NoReturn: + buffer = [] while not exit_event.is_set(): - raw = in_q.get() - if raw is None: + try: + raw = in_q.get(timeout=1) + except Empty: + continue + # if shuffle the data, 50% chance to save this data point for later + if shuffle and (random.random() < 0.5): in_q.task_done() - out_q.put(None) - break + in_q.put(raw) + continue + # Store raw into buffer until there are enough data points for a batch + buffer.append(raw) + in_q.task_done() + if len(buffer) < batch_size: + continue try: data = BaseLoader._parse_edge_data( - raw = raw, + raw = buffer, e_in_feats = e_in_feats, e_out_labels = e_out_labels, e_extra_feats = e_extra_feats, @@ -894,9 +931,10 @@ def _read_edge_data( except Exception as err: warnings.warn("Error parsing an edge batch. 
Set logging level to ERROR for details.") logger.error(err, exc_info=True) - logger.error("Error parsing data: {}".format(raw)) + logger.error("Error parsing data: {}".format(buffer)) logger.error("Parameters:\n e_in_feats={}\n e_out_labels={}\n e_extra_feats={}\n e_attr_types={}\n delimiter={}\n".format( e_in_feats, e_out_labels, e_extra_feats, e_attr_types, delimiter)) + buffer.clear() @staticmethod def _parse_vertex_data( @@ -910,19 +948,19 @@ def _parse_vertex_data( ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]: """Parse raw vertex data into dataframes. """ - # Read in vertex CSVs as dataframes + # Read in vertex CSVs as dataframes + # Each row is in format vid,v_in_feats,v_out_labels,v_extra_feats + # or vtype,vid,v_in_feats,v_out_labels,v_extra_feats + v_file = (line.strip().split(delimiter) for line in raw) if not is_hetero: # String of vertices in format vid,v_in_feats,v_out_labels,v_extra_feats v_attributes = ["vid"] + v_in_feats + v_out_labels + v_extra_feats - v_file = (line.strip().split(delimiter) for line in raw.splitlines()) data = pd.DataFrame(v_file, columns=v_attributes, dtype="object") for v_attr in v_extra_feats: if v_attr_types.get(v_attr, "") == "MAP": # I am sorry that this is this ugly... data[v_attr] = data[v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) else: - # String of vertices in format vtype,vid,v_in_feats,v_out_labels,v_extra_feats - v_file = (line.strip().split(delimiter) for line in raw.splitlines()) v_file_dict = defaultdict(list) for line in v_file: v_file_dict[line[0]].append(line[1:]) @@ -951,21 +989,18 @@ def _parse_edge_data( ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]: """Parse raw edge data into dataframes. """ - # Read in edge CSVs as dataframes + # Read in edge CSVs as dataframes + # Each row is in format source_vid,target_vid,e_in_feats,e_out_labels,e_extra_feats + # or etype,source_vid,target_vid,e_in_feats,e_out_labels,e_extra_feats + e_file = (line.strip().split(delimiter) for line in raw) if not is_hetero: - # String of edges in format source_vid,target_vid,... e_attributes = ["source", "target"] + e_in_feats + e_out_labels + e_extra_feats - #file = "\n".join(x for x in raw.split("\n") if x.strip()) - #data = pd.read_table(io.StringIO(file), header=None, names=e_attributes, sep=delimiter) - e_file = (line.strip().split(delimiter) for line in raw.splitlines()) data = pd.DataFrame(e_file, columns=e_attributes, dtype="object") for e_attr in e_extra_feats: if e_attr_types.get(e_attr, "") == "MAP": # I am sorry that this is this ugly... data[e_attr] = data[e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) else: - # String of edges in format etype,source_vid,target_vid,... - e_file = (line.strip().split(delimiter) for line in raw.splitlines()) e_file_dict = defaultdict(list) for line in e_file: e_file_dict[line[0]].append(line[1:]) @@ -979,7 +1014,7 @@ def _parse_edge_data( for e_attr in e_extra_feats.get(etype, []): if e_attr_types[etype][e_attr] == "MAP": # I am sorry that this is this ugly... 
- data[etype][e_attr] = data[etype][e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) + data[etype][e_attr] = data[etype][e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) return data @staticmethod @@ -1028,7 +1063,7 @@ def _parse_graph_data_to_df( e_attr_types = e_attr_types, delimiter = delimiter, is_hetero = is_hetero - ) + ) return (vertices, edges) @staticmethod @@ -1108,7 +1143,6 @@ def _parse_df_to_pyg( e_extra_feats: Union[list, dict] = [], e_attr_types: dict = {}, add_self_loop: bool = False, - reindex: bool = True, is_hetero: bool = False, torch = None, pyg = None @@ -1291,7 +1325,6 @@ def _parse_df_to_dgl( e_extra_feats: Union[list, dict] = [], e_attr_types: dict = {}, add_self_loop: bool = False, - reindex: bool = True, is_hetero: bool = False, torch = None, dgl = None @@ -1491,7 +1524,6 @@ def _parse_df_to_spektral( e_extra_feats: Union[list, dict] = [], e_attr_types: dict = {}, add_self_loop: bool = False, - reindex: bool = True, is_hetero: bool = False, scipy = None, spektral = None diff --git a/tests/test_gds_BaseLoader.py b/tests/test_gds_BaseLoader.py index af75cec6..396ef453 100644 --- a/tests/test_gds_BaseLoader.py +++ b/tests/test_gds_BaseLoader.py @@ -1,7 +1,7 @@ import io import unittest -from queue import Queue -from threading import Event +from queue import Queue, Empty +from threading import Event, Thread import pandas as pd import torch @@ -143,101 +143,170 @@ def test_read_vertex(self): read_task_q = Queue() data_q = Queue(4) exit_event = Event() - raw = "99|1 0 0 1 |1|0|1\n8|1 0 0 1 |1|1|1\n" - read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_vertex_data( - exit_event, - read_task_q, - data_q, - ["x"], - ["y"], - ["train_mask", "is_seed"], - {"x": "INT", "y": "INT", "train_mask": "BOOL", "is_seed": "BOOL"}, - delimiter="|" - ) + raw = ["99|1 0 0 1 |1|0|1\n", + "8|1 0 0 1 |1|1|1\n"] + for i in raw: + read_task_q.put(i) + thread = Thread( + target=self.loader._read_vertex_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 2, + v_in_feats = ["x"], + v_out_labels = ["y"], + v_extra_feats = ["train_mask", "is_seed"], + v_attr_types = {"x": "INT", "y": "INT", "train_mask": "BOOL", "is_seed": "BOOL"}, + delimiter = "|" + ) + ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() truth = pd.read_csv( - io.StringIO(raw), + io.StringIO("".join(raw)), header=None, names=["vid", "x", "y", "train_mask", "is_seed"], sep=self.loader.delimiter ) assert_frame_equal(data, truth) + + def test_read_vertex_shuffle(self): + read_task_q = Queue() + data_q = Queue(4) + exit_event = Event() + raw = ["99|1 0 0 1 |1|0|1\n", + "8|1 0 0 1 |1|1|1\n"] + for i in raw: + read_task_q.put(i) + thread = Thread( + target=self.loader._read_vertex_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 2, + shuffle= True, + v_in_feats = ["x"], + v_out_labels = ["y"], + v_extra_feats = ["train_mask", "is_seed"], + v_attr_types = {"x": "INT", "y": "INT", "train_mask": "BOOL", "is_seed": "BOOL"}, + delimiter = "|" + ) + ) + thread.start() data = data_q.get() - self.assertIsNone(data) + exit_event.set() + thread.join() + truth1 = pd.read_csv( + io.StringIO("".join(raw)), + header=None, + names=["vid", "x", "y", "train_mask", "is_seed"], + sep=self.loader.delimiter + ) + 
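# With shuffle=True the reader re-queues each incoming item with 50% probability,
# so the two rows may arrive in either order; build the reversed frame as well and
# accept either ordering below.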
raw.reverse() + truth2 = pd.read_csv( + io.StringIO("".join(raw)), + header=None, + names=["vid", "x", "y", "train_mask", "is_seed"], + sep=self.loader.delimiter + ) + self.assertTrue((data==truth1).all().all() or (data==truth2).all().all()) def test_read_vertex_callback(self): read_task_q = Queue() data_q = Queue(4) exit_event = Event() - raw = "99|1 0 0 1 |1|0|1\n8|1 0 0 1 |1|1|1\n" - read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_vertex_data( - exit_event, - read_task_q, - data_q, - ["x"], - ["y"], - ["train_mask", "is_seed"], - {"x": "INT", "y": "INT", "train_mask": "BOOL", "is_seed": "BOOL"}, - callback_fn=lambda x: 1, - delimiter="|" - ) + raw = ["99|1 0 0 1 |1|0|1\n", + "8|1 0 0 1 |1|1|1\n"] + for i in raw: + read_task_q.put(i) + thread = Thread( + target=self.loader._read_vertex_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 2, + v_in_feats = ["x"], + v_out_labels = ["y"], + v_extra_feats = ["train_mask", "is_seed"], + v_attr_types = {"x": "INT", "y": "INT", "train_mask": "BOOL", "is_seed": "BOOL"}, + delimiter = "|", + callback_fn = lambda x: 1 + ) + ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertEqual(1, data) def test_read_edge(self): read_task_q = Queue() data_q = Queue(4) exit_event = Event() - raw = "1|2|0.1|2021|1|0\n2|1|1.5|2020|0|1\n" - read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_edge_data( - exit_event, - read_task_q, - data_q, - ["x", "time"], - ["y"], - ["is_train"], - {"x": "FLOAT", "time": "INT", "y": "INT", "is_train": "BOOL"}, - delimiter="|" - ) + raw = ["1|2|0.1|2021|1|0\n", + "2|1|1.5|2020|0|1\n"] + for i in raw: + read_task_q.put(i) + thread = Thread( + target=self.loader._read_edge_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 2, + e_in_feats = ["x", "time"], + e_out_labels = ["y"], + e_extra_feats = ["is_train"], + e_attr_types = {"x": "FLOAT", "time": "INT", "y": "INT", "is_train": "BOOL"}, + delimiter = "|" + ) + ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() truth = pd.read_csv( - io.StringIO(raw), + io.StringIO("".join(raw)), header=None, names=["source", "target", "x", "time", "y", "is_train"], sep=self.loader.delimiter, ) assert_frame_equal(data, truth) - data = data_q.get() - self.assertIsNone(data) def test_read_edge_callback(self): read_task_q = Queue() data_q = Queue(4) exit_event = Event() - raw = "1|2|0.1|2021|1|0\n2|1|1.5|2020|0|1\n" - read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_edge_data( - exit_event, - read_task_q, - data_q, - ["x", "time"], - ["y"], - ["is_train"], - {"x": "FLOAT", "time": "INT", "y": "INT", "is_train": "BOOL"}, - callback_fn=lambda x: 1, - delimiter="|" - ) + raw = ["1|2|0.1|2021|1|0\n", + "2|1|1.5|2020|0|1\n"] + for i in raw: + read_task_q.put(i) + thread = Thread( + target=self.loader._read_edge_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 2, + e_in_feats = ["x", "time"], + e_out_labels = ["y"], + e_extra_feats = ["is_train"], + e_attr_types = {"x": "FLOAT", "time": "INT", "y": "INT", "is_train": "BOOL"}, + delimiter = "|", + callback_fn=lambda x: 1 + ) + ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertEqual(data, 1) - def test_read_graph_out_df(self): read_task_q = Queue() data_q = Queue(4) @@ -247,23 +316,28 @@ def test_read_graph_out_df(self): "1|2|0.1|2021|1|0\n2|1|1.5|2020|0|1\n", ) 
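# raw is a single (vertex_batch, edge_batch) message; with batch_size=1 the reader
# parses it as soon as it arrives. Vertex rows follow vid|x|y|train_mask|is_seed and
# edge rows follow source|target|x|time|y|is_train, matching the feature lists passed below.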
read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_graph_data( - exit_event, - read_task_q, - data_q, - "dataframe", - ["x"], - ["y"], - ["train_mask", "is_seed"], - {"x": "INT", "y": "INT", "train_mask": "BOOL", "is_seed": "BOOL"}, - ["x", "time"], - ["y"], - ["is_train"], - {"x": "FLOAT", "time": "INT", "y": "INT", "is_train": "BOOL"}, - delimiter="|" - ) + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + v_in_feats = ["x"], + v_out_labels = ["y"], + v_extra_feats = ["train_mask", "is_seed"], + v_attr_types = {"x": "INT", "y": "INT", "train_mask": "BOOL", "is_seed": "BOOL"}, + e_in_feats = ["x", "time"], + e_out_labels = ["y"], + e_extra_feats = ["is_train"], + e_attr_types = {"x": "FLOAT", "time": "INT", "y": "INT", "is_train": "BOOL"}, + delimiter = "|" + ) + ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() vertices = pd.read_csv( io.StringIO(raw[0]), header=None, @@ -278,9 +352,6 @@ def test_read_graph_out_df(self): ) assert_frame_equal(data[0], vertices) assert_frame_equal(data[1], edges) - data = data_q.get() - self.assertIsNone(data) - def test_read_graph_out_df_callback(self): read_task_q = Queue() @@ -291,28 +362,32 @@ def test_read_graph_out_df_callback(self): "1|2|0.1|2021|1|0\n2|1|1.5|2020|0|1\n", ) read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_graph_data( - exit_event, - read_task_q, - data_q, - "dataframe", - ["x"], - ["y"], - ["train_mask", "is_seed"], - {"x": "INT", "y": "INT", "train_mask": "BOOL", "is_seed": "BOOL"}, - ["x", "time"], - ["y"], - ["is_train"], - {"x": "FLOAT", "time": "INT", "y": "INT", "is_train": "BOOL"}, - callback_fn=lambda x: (1, 2), - delimiter="|" - ) + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + v_in_feats = ["x"], + v_out_labels = ["y"], + v_extra_feats = ["train_mask", "is_seed"], + v_attr_types = {"x": "INT", "y": "INT", "train_mask": "BOOL", "is_seed": "BOOL"}, + e_in_feats = ["x", "time"], + e_out_labels = ["y"], + e_extra_feats = ["is_train"], + e_attr_types = {"x": "FLOAT", "time": "INT", "y": "INT", "is_train": "BOOL"}, + delimiter = "|", + callback_fn = lambda x: (1, 2), + ) + ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertEqual(data[0], 1) self.assertEqual(data[1], 2) - def test_read_graph_out_pyg(self): read_task_q = Queue() data_q = Queue(4) @@ -322,29 +397,36 @@ def test_read_graph_out_pyg(self): "99|8|0.1|2021|1|0|a b \n8|99|1.5|2020|0|1|c d \n", ) read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_graph_data( - exit_event, - read_task_q, - data_q, - "pyg", - ["x"], - ["y"], - ["train_mask", "name", "is_seed"], - { - "x": "LIST:INT", - "y": "INT", - "train_mask": "BOOL", - "name": "STRING", - "is_seed": "BOOL", - }, - ["x", "time"], - ["y"], - ["is_train", "category"], - {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL", "category": "LIST:STRING"}, - delimiter="|" - ) + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + out_format = "pyg", + v_in_feats = ["x"], + v_out_labels = ["y"], + v_extra_feats = ["train_mask", "name", "is_seed"], + v_attr_types = + { + "x": "LIST:INT", + "y": "INT", + "train_mask": "BOOL", + "name": "STRING", + "is_seed": "BOOL", + }, + e_in_feats = ["x", "time"], + 
e_out_labels = ["y"], + e_extra_feats = ["is_train", "category"], + e_attr_types = {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL", "category": "LIST:STRING"}, + delimiter = "|" + ) + ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertIsInstance(data, pygData) assert_close_torch(data["edge_index"], torch.tensor([[0, 1], [1, 0]])) assert_close_torch( @@ -359,8 +441,6 @@ def test_read_graph_out_pyg(self): assert_close_torch(data["is_seed"], torch.tensor([True, False])) self.assertListEqual(data["name"], ["Alex", "Bill"]) self.assertListEqual(data["category"], [['a', 'b'], ['c', 'd']]) - data = data_q.get() - self.assertIsNone(data) def test_read_graph_out_dgl(self): read_task_q = Queue() @@ -371,29 +451,36 @@ def test_read_graph_out_dgl(self): "99|8|0.1|2021|1|0|a b \n8|99|1.5|2020|0|1|c d \n", ) read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_graph_data( - exit_event, - read_task_q, - data_q, - "dgl", - ["x"], - ["y"], - ["train_mask", "name", "is_seed"], - { - "x": "LIST:INT", - "y": "INT", - "train_mask": "BOOL", - "name": "STRING", - "is_seed": "BOOL", - }, - ["x", "time"], - ["y"], - ["is_train", "category"], - {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL", "category": "LIST:STRING"}, - delimiter="|" - ) + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + out_format = "dgl", + v_in_feats = ["x"], + v_out_labels = ["y"], + v_extra_feats = ["train_mask", "name", "is_seed"], + v_attr_types = + { + "x": "LIST:INT", + "y": "INT", + "train_mask": "BOOL", + "name": "STRING", + "is_seed": "BOOL", + }, + e_in_feats = ["x", "time"], + e_out_labels = ["y"], + e_extra_feats = ["is_train", "category"], + e_attr_types = {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL", "category": "LIST:STRING"}, + delimiter = "|" + ) + ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertIsInstance(data, DGLGraph) assert_close_torch(data.edges(), (torch.tensor([0, 1]), torch.tensor([1, 0]))) assert_close_torch( @@ -408,8 +495,6 @@ def test_read_graph_out_dgl(self): assert_close_torch(data.ndata["is_seed"], torch.tensor([True, False])) self.assertListEqual(data.extra_data["name"], ["Alex", "Bill"]) self.assertListEqual(data.extra_data["category"], [['a', 'b'], ['c', 'd']]) - data = data_q.get() - self.assertIsNone(data) def test_read_graph_parse_error(self): read_task_q = Queue() @@ -420,30 +505,37 @@ def test_read_graph_parse_error(self): "99|8|0.1|2021|1|0|a b \n8|99|1.5|2020|0|1|c d \n", ) read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_graph_data( - exit_event, - read_task_q, - data_q, - "dgl", - ["x"], - ["y"], - ["train_mask", "name", "is_seed"], - { - "x": "LIST:INT", - "y": "INT", - "train_mask": "BOOL", - "name": "STRING", - "is_seed": "BOOL", - }, - ["x", "time"], - ["y"], - ["is_train", "category"], - {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL", "category": "LIST:STRING"}, - delimiter="|" - ) - data = data_q.get() - self.assertIsNone(data) + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + out_format = "dgl", + v_in_feats = ["x"], + v_out_labels = ["y"], + v_extra_feats = ["train_mask", "name", "is_seed"], + v_attr_types = + { + "x": "LIST:INT", + "y": "INT", + "train_mask": "BOOL", + "name": "STRING", + "is_seed": "BOOL", + }, + 
e_in_feats = ["x", "time"], + e_out_labels = ["y"], + e_extra_feats = ["is_train", "category"], + e_attr_types = {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL", "category": "LIST:STRING"}, + delimiter = "|" + ) + ) + thread.start() + with self.assertRaises(Empty): + data = data_q.get(timeout=1) + exit_event.set() + thread.join() def test_read_graph_no_attr(self): read_task_q = Queue() @@ -451,34 +543,33 @@ def test_read_graph_no_attr(self): exit_event = Event() raw = ("99|1\n8|0\n", "99|8\n8|99\n") read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_graph_data( - exit_event, - read_task_q, - data_q, - "pyg", - [], - [], - ["is_seed"], - { - "x": "INT", - "y": "INT", - "train_mask": "BOOL", - "name": "STRING", - "is_seed": "BOOL", - }, - [], - [], - [], - {}, - delimiter="|" - ) + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + out_format = "pyg", + v_extra_feats = ["is_seed"], + v_attr_types = + { + "x": "LIST:INT", + "y": "INT", + "train_mask": "BOOL", + "name": "STRING", + "is_seed": "BOOL", + }, + delimiter = "|" + ) + ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertIsInstance(data, pygData) assert_close_torch(data["edge_index"], torch.tensor([[0, 1], [1, 0]])) assert_close_torch(data["is_seed"], torch.tensor([True, False])) - data = data_q.get() - self.assertIsNone(data) def test_read_graph_no_edge(self): read_task_q = Queue() @@ -489,29 +580,36 @@ def test_read_graph_no_edge(self): "", ) read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_graph_data( - exit_event, - read_task_q, - data_q, - "pyg", - ["x"], - ["y"], - ["train_mask", "name", "is_seed"], - { - "x": "LIST:INT", - "y": "INT", - "train_mask": "BOOL", - "name": "STRING", - "is_seed": "BOOL", - }, - ["x", "time"], - ["y"], - ["is_train"], - {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL"}, - delimiter="|" - ) + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + out_format = "pyg", + v_in_feats = ["x"], + v_out_labels = ["y"], + v_extra_feats = ["train_mask", "name", "is_seed"], + v_attr_types = + { + "x": "LIST:INT", + "y": "INT", + "train_mask": "BOOL", + "name": "STRING", + "is_seed": "BOOL", + }, + e_in_feats = ["x", "time"], + e_out_labels = ["y"], + e_extra_feats = ["is_train"], + e_attr_types = {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL"}, + delimiter = "|" + ) + ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertIsInstance(data, pygData) self.assertListEqual(list(data["edge_index"].shape), [2,0]) self.assertListEqual(list(data["edge_feat"].shape), [0,2]) @@ -522,8 +620,6 @@ def test_read_graph_no_edge(self): assert_close_torch(data["train_mask"], torch.tensor([False, True])) assert_close_torch(data["is_seed"], torch.tensor([True, False])) self.assertListEqual(data["name"], ["Alex", "Bill"]) - data = data_q.get() - self.assertIsNone(data) def test_read_hetero_graph_out_pyg(self): read_task_q = Queue() @@ -534,50 +630,53 @@ def test_read_hetero_graph_out_pyg(self): "Colleague|99|8|0.1|2021|1|0\nColleague|8|99|1.5|2020|0|1\nWork|99|2\nWork|2|8\n", ) read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_graph_data( - exit_event, - read_task_q, - data_q, - "pyg", - {"People": ["x"], "Company": ["x"]}, - {"People": ["y"]}, - {"People": ["train_mask", "name", 
"is_seed"], "Company": ["is_seed"]}, - { - "People": { - "x": "LIST:INT", - "y": "INT", - "train_mask": "BOOL", - "name": "STRING", - "is_seed": "BOOL", + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + out_format = "pyg", + v_in_feats = {"People": ["x"], "Company": ["x"]}, + v_out_labels = {"People": ["y"]}, + v_extra_feats = {"People": ["train_mask", "name", "is_seed"], "Company": ["is_seed"]}, + v_attr_types = + { + "People": { + "x": "LIST:INT", + "y": "INT", + "train_mask": "BOOL", + "name": "STRING", + "is_seed": "BOOL", + }, + "Company": {"x": "FLOAT", "is_seed": "BOOL"}, + }, + e_in_feats = {"Colleague": ["x", "time"]}, + e_out_labels = {"Colleague": ["y"]}, + e_extra_feats = {"Colleague": ["is_train"]}, + e_attr_types = { + "Colleague": { + "FromVertexTypeName": "People", + "ToVertexTypeName": "People", + "IsDirected": False, + "x": "DOUBLE", + "time": "INT", + "y": "INT", + "is_train": "BOOL"}, + "Work": { + "FromVertexTypeName": "People", + "ToVertexTypeName": "Company", + "IsDirected": False} }, - "Company": {"x": "FLOAT", "is_seed": "BOOL"}, - }, - {"Colleague": ["x", "time"]}, - {"Colleague": ["y"]}, - {"Colleague": ["is_train"]}, - { - "Colleague": { - "FromVertexTypeName": "People", - "ToVertexTypeName": "People", - "IsDirected": False, - "x": "DOUBLE", - "time": "INT", - "y": "INT", - "is_train": "BOOL", - }, - "Work": { - "FromVertexTypeName": "People", - "ToVertexTypeName": "Company", - "IsDirected": False, - } - }, - False, - "|", - True, - True, + delimiter = "|", + is_hetero = True + ) ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertIsInstance(data, pygHeteroData) assert_close_torch( data["Colleague"]["edge_index"], torch.tensor([[0, 1], [1, 0]]) @@ -602,8 +701,6 @@ def test_read_hetero_graph_out_pyg(self): assert_close_torch( data["Work"]["edge_index"], torch.tensor([[0, 1], [0, 0]]) ) - data = data_q.get() - self.assertIsNone(data) def test_read_hetero_graph_no_attr(self): read_task_q = Queue() @@ -614,50 +711,53 @@ def test_read_hetero_graph_no_attr(self): "Colleague|99|8\nColleague|8|99\nWork|99|2\nWork|2|8\n", ) read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_graph_data( - exit_event, - read_task_q, - data_q, - "pyg", - {"People": [], "Company": []}, - {"People": [], "Company": []}, - {"People": ["is_seed"], "Company": ["is_seed"]}, - { - "People": { - "x": "LIST:INT", - "y": "INT", - "train_mask": "BOOL", - "name": "STRING", - "is_seed": "BOOL", + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + out_format = "pyg", + v_in_feats = {"People": [], "Company": []}, + v_out_labels = {"People": [], "Company": []}, + v_extra_feats = {"People": ["is_seed"], "Company": ["is_seed"]}, + v_attr_types = + { + "People": { + "x": "LIST:INT", + "y": "INT", + "train_mask": "BOOL", + "name": "STRING", + "is_seed": "BOOL", + }, + "Company": {"x": "FLOAT", "is_seed": "BOOL"}, + }, + e_in_feats = {"Colleague": [], "Work": []}, + e_out_labels = {"Colleague": [], "Work": []}, + e_extra_feats = {"Colleague": [], "Work": []}, + e_attr_types = { + "Colleague": { + "FromVertexTypeName": "People", + "ToVertexTypeName": "People", + "IsDirected": False, + "x": "DOUBLE", + "time": "INT", + "y": "INT", + "is_train": "BOOL"}, + "Work": { + "FromVertexTypeName": "People", + "ToVertexTypeName": "Company", + "IsDirected": False} 
}, - "Company": {"x": "FLOAT", "is_seed": "BOOL"}, - }, - {"Colleague": [], "Work": []}, - {"Colleague": [], "Work": []}, - {"Colleague": [], "Work": []}, - { - "Colleague": { - "FromVertexTypeName": "People", - "ToVertexTypeName": "People", - "IsDirected": False, - "x": "DOUBLE", - "time": "INT", - "y": "INT", - "is_train": "BOOL", - }, - "Work": { - "FromVertexTypeName": "People", - "ToVertexTypeName": "Company", - "IsDirected": False, - } - }, - False, - "|", - True, - True, + delimiter = "|", + is_hetero = True + ) ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertIsInstance(data, pygHeteroData) assert_close_torch( data["Colleague"]["edge_index"], torch.tensor([[0, 1], [1, 0]]) @@ -667,8 +767,6 @@ def test_read_hetero_graph_no_attr(self): ) assert_close_torch(data["People"]["is_seed"], torch.tensor([True, False])) assert_close_torch(data["Company"]["is_seed"], torch.tensor([False])) - data = data_q.get() - self.assertIsNone(data) def test_read_hetero_graph_no_edge(self): read_task_q = Queue() @@ -679,50 +777,53 @@ def test_read_hetero_graph_no_edge(self): "", ) read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_graph_data( - exit_event, - read_task_q, - data_q, - "pyg", - {"People": ["x"], "Company": ["x"]}, - {"People": ["y"]}, - {"People": ["train_mask", "name", "is_seed"], "Company": ["is_seed"]}, - { - "People": { - "x": "LIST:INT", - "y": "INT", - "train_mask": "BOOL", - "name": "STRING", - "is_seed": "BOOL", + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + out_format = "pyg", + v_in_feats = {"People": ["x"], "Company": ["x"]}, + v_out_labels = {"People": ["y"]}, + v_extra_feats = {"People": ["train_mask", "name", "is_seed"], "Company": ["is_seed"]}, + v_attr_types = + { + "People": { + "x": "LIST:INT", + "y": "INT", + "train_mask": "BOOL", + "name": "STRING", + "is_seed": "BOOL", + }, + "Company": {"x": "FLOAT", "is_seed": "BOOL"}, + }, + e_in_feats = {"Colleague": ["x", "time"]}, + e_out_labels = {"Colleague": ["y"]}, + e_extra_feats = {"Colleague": ["is_train"]}, + e_attr_types = { + "Colleague": { + "FromVertexTypeName": "People", + "ToVertexTypeName": "People", + "IsDirected": False, + "x": "DOUBLE", + "time": "INT", + "y": "INT", + "is_train": "BOOL"}, + "Work": { + "FromVertexTypeName": "People", + "ToVertexTypeName": "Company", + "IsDirected": False} }, - "Company": {"x": "FLOAT", "is_seed": "BOOL"}, - }, - {"Colleague": ["x", "time"]}, - {"Colleague": ["y"]}, - {"Colleague": ["is_train"]}, - { - "Colleague": { - "FromVertexTypeName": "People", - "ToVertexTypeName": "People", - "IsDirected": False, - "x": "DOUBLE", - "time": "INT", - "y": "INT", - "is_train": "BOOL", - }, - "Work": { - "FromVertexTypeName": "People", - "ToVertexTypeName": "Company", - "IsDirected": False, - } - }, - False, - "|", - True, - True, + delimiter = "|", + is_hetero = True + ) ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertIsInstance(data, pygHeteroData) self.assertNotIn("Colleague", data) assert_close_torch( @@ -737,8 +838,6 @@ def test_read_hetero_graph_no_edge(self): ) assert_close_torch(data["Company"]["is_seed"], torch.tensor([False])) self.assertNotIn("Work", data) - data = data_q.get() - self.assertIsNone(data) def test_read_hetero_graph_out_dgl(self): read_task_q = Queue() @@ -749,51 +848,54 @@ def test_read_hetero_graph_out_dgl(self): 
"Colleague|99|8|0.1|2021|1|0\nColleague|8|99|1.5|2020|0|1\nWork|99|2|a b \nWork|2|8|c d \n", ) read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_graph_data( - exit_event, - read_task_q, - data_q, - "dgl", - {"People": ["x"], "Company": ["x"]}, - {"People": ["y"]}, - {"People": ["train_mask", "name", "is_seed"], "Company": ["is_seed"]}, - { - "People": { - "x": "LIST:INT", - "y": "INT", - "train_mask": "BOOL", - "name": "STRING", - "is_seed": "BOOL", - }, - "Company": {"x": "FLOAT", "is_seed": "BOOL"}, - }, - {"Colleague": ["x", "time"]}, - {"Colleague": ["y"]}, - {"Colleague": ["is_train"], "Work": ["category"]}, - { - "Colleague": { - "FromVertexTypeName": "People", - "ToVertexTypeName": "People", - "IsDirected": False, - "x": "DOUBLE", - "time": "INT", - "y": "INT", - "is_train": "BOOL", + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + out_format = "dgl", + v_in_feats = {"People": ["x"], "Company": ["x"]}, + v_out_labels = {"People": ["y"]}, + v_extra_feats = {"People": ["train_mask", "name", "is_seed"], "Company": ["is_seed"]}, + v_attr_types = + { + "People": { + "x": "LIST:INT", + "y": "INT", + "train_mask": "BOOL", + "name": "STRING", + "is_seed": "BOOL", + }, + "Company": {"x": "FLOAT", "is_seed": "BOOL"}, + }, + e_in_feats = {"Colleague": ["x", "time"]}, + e_out_labels = {"Colleague": ["y"]}, + e_extra_feats = {"Colleague": ["is_train"], "Work": ["category"]}, + e_attr_types = { + "Colleague": { + "FromVertexTypeName": "People", + "ToVertexTypeName": "People", + "IsDirected": False, + "x": "DOUBLE", + "time": "INT", + "y": "INT", + "is_train": "BOOL"}, + "Work": { + "FromVertexTypeName": "People", + "ToVertexTypeName": "Company", + "IsDirected": False, + "category": "LIST:STRING"} }, - "Work": { - "FromVertexTypeName": "People", - "ToVertexTypeName": "Company", - "IsDirected": False, - "category": "LIST:STRING" - } - }, - False, - "|", - True, - True, + delimiter = "|", + is_hetero = True + ) ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertIsInstance(data, DGLGraph) assert_close_torch( data.edges(etype="Colleague"), (torch.tensor([0, 1]), torch.tensor([1, 0])) @@ -819,8 +921,6 @@ def test_read_hetero_graph_out_dgl(self): data.edges(etype="Work"), (torch.tensor([0, 1]), torch.tensor([0, 0])) ) self.assertListEqual(data.extra_data["Work"]["category"], [['a', 'b'], ['c', 'd']]) - data = data_q.get() - self.assertIsNone(data) def test_read_bool_label(self): read_task_q = Queue() @@ -831,29 +931,36 @@ def test_read_bool_label(self): "99|8|0.1|2021|1|0\n8|99|1.5|2020|0|1\n", ) read_task_q.put(raw) - read_task_q.put(None) - self.loader._read_graph_data( - exit_event, - read_task_q, - data_q, - "pyg", - ["x"], - ["y"], - ["train_mask", "name", "is_seed"], - { - "x": "LIST:INT", - "y": "BOOL", - "train_mask": "BOOL", - "name": "STRING", - "is_seed": "BOOL", - }, - ["x", "time"], - ["y"], - ["is_train"], - {"x": "DOUBLE", "time": "INT", "y": "BOOL", "is_train": "BOOL"}, - delimiter="|" - ) + thread = Thread( + target=self.loader._read_graph_data, + kwargs=dict( + exit_event = exit_event, + in_q = read_task_q, + out_q = data_q, + batch_size = 1, + out_format = "pyg", + v_in_feats = ["x"], + v_out_labels = ["y"], + v_extra_feats = ["train_mask", "name", "is_seed"], + v_attr_types = + { + "x": "LIST:INT", + "y": "BOOL", + "train_mask": "BOOL", + "name": "STRING", + "is_seed": "BOOL", + }, + e_in_feats = ["x", "time"], + e_out_labels = 
["y"], + e_extra_feats = ["is_train"], + e_attr_types = {"x": "DOUBLE", "time": "INT", "y": "BOOL", "is_train": "BOOL"}, + delimiter = "|" + ) + ) + thread.start() data = data_q.get() + exit_event.set() + thread.join() self.assertIsInstance(data, pygData) assert_close_torch(data["edge_index"], torch.tensor([[0, 1], [1, 0]])) assert_close_torch( @@ -867,17 +974,16 @@ def test_read_bool_label(self): assert_close_torch(data["train_mask"], torch.tensor([False, True])) assert_close_torch(data["is_seed"], torch.tensor([True, False])) self.assertListEqual(data["name"], ["Alex", "Bill"]) - data = data_q.get() - self.assertIsNone(data) if __name__ == "__main__": suite = unittest.TestSuite() - suite.addTest(TestGDSBaseLoader("test_get_schema")) - suite.addTest(TestGDSBaseLoader("test_get_schema_no_primary_id_attr")) - suite.addTest(TestGDSBaseLoader("test_validate_vertex_attributes")) - suite.addTest(TestGDSBaseLoader("test_validate_edge_attributes")) + # suite.addTest(TestGDSBaseLoader("test_get_schema")) + # suite.addTest(TestGDSBaseLoader("test_get_schema_no_primary_id_attr")) + # suite.addTest(TestGDSBaseLoader("test_validate_vertex_attributes")) + # suite.addTest(TestGDSBaseLoader("test_validate_edge_attributes")) suite.addTest(TestGDSBaseLoader("test_read_vertex")) + suite.addTest(TestGDSBaseLoader("test_read_vertex_shuffle")) suite.addTest(TestGDSBaseLoader("test_read_vertex_callback")) suite.addTest(TestGDSBaseLoader("test_read_edge")) suite.addTest(TestGDSBaseLoader("test_read_edge_callback")) From 9eaef0eb7241277383f6b29033ef70eded3f7b3b Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Tue, 10 Oct 2023 13:10:29 -0700 Subject: [PATCH 05/36] fix(BaseLoader): upate _start_request for the refactor --- pyTigerGraph/gds/dataloaders.py | 82 ++++++++++++++++++++++----------- 1 file changed, 55 insertions(+), 27 deletions(-) diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index 51f1dd9d..d77b3d35 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -1616,22 +1616,32 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, raise NotImplementedError return data - def _start_request(self, out_tuple: bool, resp_type: str): + def _start_request(self, is_graph: bool): # If using kafka if self.kafka_address_consumer: # Generate topic self._set_kafka_topic() # Start consumer thread - self._downloader = Thread( - target=self._download_from_kafka, - args=( - self._exit_event, - self._read_task_q, - self.num_batches, - out_tuple, - self._kafka_consumer, - ), - ) + if is_graph: + self._downloader = Thread( + target=self._download_graph_kafka, + kwargs=dict( + exit_event = self._exit_event, + read_task_q = self._read_task_q, + kafka_consumer = self._kafka_consumer, + max_wait_time = self.timeout + ), + ) + else: + self._downloader = Thread( + target=self._download_unimode_kafka, + kwargs=dict( + exit_event = self._exit_event, + read_task_q = self._read_task_q, + kafka_consumer = self._kafka_consumer, + max_wait_time = self.timeout + ), + ) self._downloader.start() # Start requester thread if not self.kafka_skip_produce: @@ -1648,17 +1658,28 @@ def _start_request(self, out_tuple: bool, resp_type: str): self._requester.start() else: # Otherwise, use rest api - self._requester = Thread( - target=self._request_rest, - args=( - self._graph, - self.query_name, - self._read_task_q, - self.timeout, - self._payload, - resp_type, - ), - ) + if is_graph: + self._requester = Thread( + target=self._request_graph_rest, + kwargs=dict( + tgraph 
= self._graph, + query_name = self.query_name, + read_task_q = self._read_task_q, + timeout = self.timeout, + payload = self._payload + ), + ) + else: + self._requester = Thread( + target=self._request_unimode_rest, + kwargs=dict( + tgraph = self._graph, + query_name = self.query_name, + read_task_q = self._read_task_q, + timeout = self.timeout, + payload = self._payload + ), + ) self._requester.start() def _start(self) -> None: @@ -1730,20 +1751,27 @@ def data(self) -> Any: return self def _reset(self, theend=False) -> None: - logging.debug("Resetting the loader") + logger.debug("Resetting the data loader") if self._exit_event: self._exit_event.set() if self._request_task_q: - self._request_task_q.put(None) + while True: + try: + self._request_task_q.get(block=False) + except Empty: + break if self._download_task_q: - self._download_task_q.put(None) + while True: + try: + self._download_task_q.get(block=False) + except Empty: + break if self._read_task_q: while True: try: self._read_task_q.get(block=False) except Empty: break - self._read_task_q.put(None) if self._data_q: while True: try: @@ -1787,7 +1815,7 @@ def _reset(self, theend=False) -> None: "Failed to delete topic {}".format(del_res["topic"]) ) self._kafka_topic = None - logging.debug("Successfully reset the loader") + logger.debug("Successfully reset the data loader") def _generate_attribute_string(self, schema_type, attr_names, attr_types) -> str: if schema_type.lower() == "vertex": From e4e407ea1bbb4dec30a60694d312080161fa299c Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Mon, 16 Oct 2023 16:38:04 -0700 Subject: [PATCH 06/36] refactor(VertexLoader): update gsql template --- pyTigerGraph/gds/dataloaders.py | 45 +++-- .../gds/gsql/dataloaders/vertex_loader.gsql | 156 ++++++------------ 2 files changed, 82 insertions(+), 119 deletions(-) diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index d77b3d35..0e972803 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -2658,7 +2658,7 @@ class VertexLoader(BaseLoader): print("----Batch {}: Shape {}----".format(i, batch.shape)) print(batch.head(1)) <1> ---- - <1> Since the example does not provide an output format, the output format defaults to panda frames, have access to the methods of panda frame instances. + <1> The output format is Pandas dataframe. -- Output:: + @@ -2811,7 +2811,11 @@ def __init__( self._vtypes = sorted(self._vtypes) # Initialize parameters for the query if batch_size: - # If batch_size is given, calculate the number of batches + # batch size takes precedence over number of batches + self.batch_size = batch_size + self.num_batches = None + else: + # If number of batches is given, calculate batch size num_vertices_by_type = self._graph.getVertexCount(self._vtypes) if filter_by: num_vertices = sum( @@ -2820,20 +2824,13 @@ def __init__( ) else: num_vertices = sum(num_vertices_by_type.values()) - self.num_batches = math.ceil(num_vertices / batch_size) - else: - # Otherwise, take the number of batches as is. 
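The sizing rule introduced in this hunk is easiest to see with concrete numbers. The helper below only illustrates the arithmetic (the function name is made up, not part of the library); the 140-vertex figure matches the Cora "train_mask" split that the tests in this series assert against.

    import math

    def resolve_batch_size(num_vertices, batch_size=None, num_batches=1):
        # batch_size takes precedence; the total batch count is then unknown up front.
        if batch_size:
            return batch_size, None
        # Otherwise derive the batch size from the requested number of batches.
        return math.ceil(num_vertices / num_batches), num_batches

    print(resolve_batch_size(140, batch_size=16))   # (16, None): 8 batches of 16 plus a final 12
    print(resolve_batch_size(140, num_batches=7))   # (20, 7)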
+ self.batch_size = math.ceil(num_vertices / num_batches) self.num_batches = num_batches - self._payload["num_batches"] = self.num_batches if filter_by: self._payload["filter_by"] = filter_by - if batch_size: - self._payload["batch_size"] = batch_size - self._payload["shuffle"] = shuffle self._payload["delimiter"] = delimiter self._payload["v_types"] = self._vtypes self._payload["input_vertices"] = [] - self._payload["num_heap_inserts"] = self.num_heap_inserts # Install query self.query_name = self._install_query() @@ -2849,31 +2846,43 @@ def _install_query(self, force: bool = False) -> str: if isinstance(self.attributes, dict): # Multiple vertex types - print_query = "" + print_query_kafka = "" + print_query_http = "" for idx, vtype in enumerate(self._vtypes): v_attr_names = self.attributes.get(vtype, []) v_attr_types = self._v_schema[vtype] if v_attr_names: print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + "\\n")\n'.format( + print_query_http += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + "\\n")\n'.format( + "IF" if idx==0 else "ELSE IF", vtype, print_attr) + print_query_kafka += '{} s.type == "{}" THEN \n s.@v_data += (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + "\\n")\n'.format( "IF" if idx==0 else "ELSE IF", vtype, print_attr) else: - print_query += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + "\\n")\n'.format( + print_query_http += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + "\\n")\n'.format( "IF" if idx==0 else "ELSE IF", vtype) - print_query += "END" - query_replace["{VERTEXATTRS}"] = print_query + print_query_kafka += '{} s.type == "{}" THEN \n s.@v_data += (s.type + delimiter + stringify(getvid(s)) + "\\n")\n'.format( + "IF" if idx==0 else "ELSE IF", vtype) + print_query_http += "END" + print_query_kafka += "END" + query_replace["{VERTEXATTRSHTTP}"] = print_query_http + query_replace["{VERTEXATTRSKAFKA}"] = print_query_kafka else: # Ignore vertex types v_attr_names = self.attributes v_attr_types = next(iter(self._v_schema.values())) if v_attr_names: print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query = '@@v_batch += (stringify(getvid(s)) + delimiter + {} + "\\n")'.format( + print_query_http = '@@v_batch += (stringify(getvid(s)) + delimiter + {} + "\\n")'.format( + print_attr + ) + print_query_kafka = 's.@v_data += (stringify(getvid(s)) + delimiter + {} + "\\n")'.format( print_attr ) else: - print_query = '@@v_batch += (stringify(getvid(s)) + "\\n")' - query_replace["{VERTEXATTRS}"] = print_query + print_query_http = '@@v_batch += (stringify(getvid(s)) + "\\n")' + print_query_kafka = 's.@v_data += (stringify(getvid(s)) + "\\n")' + query_replace["{VERTEXATTRSHTTP}"] = print_query_http + query_replace["{VERTEXATTRSKAFKA}"] = print_query_kafka # Install query query_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), diff --git a/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql b/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql index 76ffaddd..3af32f2f 100644 --- a/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql +++ b/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql @@ -1,8 +1,5 @@ CREATE QUERY vertex_loader_{QUERYSUFFIX}( SET input_vertices, - INT batch_size, - INT num_batches=1, - BOOL shuffle=FALSE, STRING 
filter_by, SET v_types, STRING delimiter, @@ -22,8 +19,7 @@ CREATE QUERY vertex_loader_{QUERYSUFFIX}( STRING ssl_endpoint_identification_algorithm="", STRING sasl_kerberos_service_name="", STRING sasl_kerberos_keytab="", - STRING sasl_kerberos_principal="", - INT num_heap_inserts = 10 + STRING sasl_kerberos_principal="" ) SYNTAX V2 { /* This query generates batches of vertices. If `input_vertices` is given, it will generate @@ -45,108 +41,66 @@ CREATE QUERY vertex_loader_{QUERYSUFFIX}( sasl_password : SASL password for Kafka. ssl_ca_location: Path to CA certificate for verifying the Kafka broker key. */ - TYPEDEF TUPLE ID_Tuple; - INT num_vertices; - INT kafka_errcode; - SumAccum @tmp_id; - SumAccum @@kafka_error; - UINT producer; - INT batch_s; - OrAccum @prev_sampled; - - # Initialize Kafka producer - IF kafka_address != "" THEN - producer = init_kafka_producer( - kafka_address, kafka_max_size, security_protocol, - sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, - ssl_certificate_location, ssl_key_location, ssl_key_password, - ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, - sasl_kerberos_keytab, sasl_kerberos_principal); - END; - - # Shuffle vertex ID if needed - start = {v_types}; - IF filter_by IS NOT NULL THEN - start = SELECT s FROM start:s WHERE s.getAttr(filter_by, "BOOL"); - END; - IF shuffle THEN - num_vertices = start.size(); - res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = floor(rand()*num_vertices); - ELSE - res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = getvid(s); - END; - IF batch_size IS NULL THEN - batch_s = ceil(res.size()/num_batches); - ELSE - batch_s = batch_size; - END; - # Generate batches - FOREACH batch_id IN RANGE[0, num_batches-1] DO - SumAccum @@v_batch; - SetAccum @@seeds; - IF input_vertices.size()==0 THEN - start = {v_types}; - HeapAccum (1, tmp_id ASC) @@batch_heap; - @@batch_heap.resize(batch_s); - IF filter_by IS NOT NULL THEN - FOREACH iter IN RANGE[0,num_heap_inserts-1] DO - _verts = SELECT s FROM start:s - WHERE s.@tmp_id % num_heap_inserts == iter AND NOT s.@prev_sampled AND s.getAttr(filter_by, "BOOL") - POST-ACCUM @@batch_heap += ID_Tuple(s.@tmp_id, s); - END; - FOREACH elem IN @@batch_heap DO - @@seeds += elem.v; - END; - seeds = {@@seeds}; - seeds = SELECT s - FROM seeds:s - POST-ACCUM - s.@prev_sampled += TRUE, - {VERTEXATTRS}; - ELSE - FOREACH iter IN RANGE[0,num_heap_inserts-1] DO - _verts = SELECT s FROM start:s - WHERE s.@tmp_id % num_heap_inserts == iter AND NOT s.@prev_sampled - POST-ACCUM @@batch_heap += ID_Tuple(s.@tmp_id, s); - END; - FOREACH elem IN @@batch_heap DO - @@seeds += elem.v; - END; - seeds = {@@seeds}; - seeds = SELECT s - FROM start:s - POST-ACCUM - s.@prev_sampled += TRUE, - {VERTEXATTRS}; - END; - ELSE - start = input_vertices; - seeds = SELECT s - FROM start:s - POST-ACCUM - {VERTEXATTRS}; - END; + # If getting all vertices of given types + IF input_vertices.size()==0 THEN + # Filter seeds if needed. 
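    # In this revision of the template: when kafka_address is set, each selected vertex is
    # serialized by the generated {VERTEXATTRSKAFKA} snippet and written to the topic with
    # key "vertex_" + getvid(s), partitioned by getvid(s) % kafka_topic_partitions; when it
    # is not set, rows are collected into @@v_batch and printed back as data_batch. The old
    # per-batch loop is gone, so batch assembly now happens on the client side (see the
    # reader-thread update later in this series).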
+ start = {v_types}; + seeds = SELECT s + FROM start:s + WHERE filter_by is NULL OR s.getAttr(filter_by, "BOOL"); + # Generate batches + # If using kafka to export IF kafka_address != "" THEN - # Write to kafka - kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "vertex_batch_" + stringify(batch_id), @@v_batch); + SumAccum @@kafka_error; + SumAccum @v_data; + + # Initialize Kafka producer + UINT producer = init_kafka_producer( + kafka_address, kafka_max_size, security_protocol, + sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, + ssl_certificate_location, ssl_key_location, ssl_key_password, + ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, + sasl_kerberos_keytab, sasl_kerberos_principal); + + res = SELECT s + FROM seeds:s + POST-ACCUM + {VERTEXATTRSKAFKA}, + INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "vertex_" + stringify(getvid(s)), s.@v_data), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending data for vertex " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\n"), + END + s.@v_data = "" + LIMIT 1; + + INT kafka_errcode = close_kafka_producer(producer, kafka_timeout); IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending vertex batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n"); + @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); END; + PRINT @@kafka_error as kafkaError; + # Else return as http response ELSE - # Add to response - PRINT @@v_batch AS vertex_batch; + ListAccum @@v_batch; + res = SELECT s + FROM seeds:s + ACCUM + {VERTEXATTRSHTTP}; + + FOREACH i IN @@v_batch DO + PRINT i as data_batch; + END; END; - END; - IF kafka_address != "" THEN - kafka_errcode = close_kafka_producer(producer, kafka_timeout); - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); + # Else get given vertices. + ELSE + ListAccum @@v_batch; + start = input_vertices; + seeds = SELECT s + FROM start:s + POST-ACCUM + {VERTEXATTRSHTTP}; + FOREACH i IN @@v_batch DO + PRINT i as data_batch; END; - PRINT @@kafka_error as kafkaError; END; } \ No newline at end of file From 4c42d4abe13d6b1494c805f9619d1f437a98d232 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Mon, 16 Oct 2023 16:49:44 -0700 Subject: [PATCH 07/36] refactor(VertexLoader): update reader thread --- pyTigerGraph/gds/dataloaders.py | 37 +++++++++++++-------------------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index 0e972803..af99d49a 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -2898,7 +2898,7 @@ def _start(self) -> None: self._data_q = Queue(self.buffer_size) self._exit_event = Event() - self._start_request(False, "vertex") + self._start_request(False) # Start reading thread. 
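The reader wired up just below pulls raw delimited rows off the read queue and groups them into batches of `batch_size` before parsing. A stripped-down sketch of that consume-and-batch loop (illustrative names only, not the library API):

    from queue import Queue
    from threading import Event, Thread

    read_task_q, data_q, exit_event = Queue(), Queue(), Event()

    def toy_reader(in_q, out_q, batch_size):
        buf = []
        while not exit_event.is_set():
            row = in_q.get()
            if row is None:              # end-of-stream sentinel from the downloader
                if buf:
                    out_q.put(list(buf)) # flush the final, possibly smaller batch
                out_q.put(None)
                break
            buf.append(row)
            if len(buf) == batch_size:
                out_q.put(list(buf))
                buf.clear()

    Thread(target=toy_reader, args=(read_task_q, data_q, 2), daemon=True).start()
    for row in ["1|0.1|1", "2|0.5|0", "3|0.9|1", None]:
        read_task_q.put(row)
    while (batch := data_q.get()) is not None:
        print(batch)                     # ['1|0.1|1', '2|0.5|0'] then ['3|0.9|1']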
if not self.is_hetero: @@ -2906,27 +2906,20 @@ def _start(self) -> None: else: v_attr_types = self._v_schema self._reader = Thread( - target=self._read_data, - args=( - self._exit_event, - self._read_task_q, - self._data_q, - "vertex", - self.output_format, - self.attributes, - {} if self.is_hetero else [], - {} if self.is_hetero else [], - v_attr_types, - [], - [], - [], - {}, - False, - self.delimiter, - False, - self.is_hetero, - self.callback_fn - ), + target=self._read_vertex_data, + kwargs=dict( + exit_event = self._exit_event, + in_q = self._read_task_q, + out_q = self._data_q, + batch_size = self.batch_size, + v_in_feats = self.attributes, + v_out_labels = {} if self.is_hetero else [], + v_extra_feats = {} if self.is_hetero else [], + v_attr_types = v_attr_types, + delimiter = self.delimiter, + is_hetero = self.is_hetero, + callback_fn = self.callback_fn + ) ) self._reader.start() From 41c45f9e1662282a5a785f2931b0bf5381781821 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Tue, 17 Oct 2023 13:22:05 -0700 Subject: [PATCH 08/36] fix(VertexLoader): fix wrong comma in gsql --- pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql b/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql index 3af32f2f..500a0b36 100644 --- a/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql +++ b/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql @@ -69,8 +69,8 @@ CREATE QUERY vertex_loader_{QUERYSUFFIX}( {VERTEXATTRSKAFKA}, INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "vertex_" + stringify(getvid(s)), s.@v_data), IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending data for vertex " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\n"), - END + @@kafka_error += ("Error sending data for vertex " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\n") + END, s.@v_data = "" LIMIT 1; From ce95a216b5eccf4c7fe236e3c627252eed8cd118 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Tue, 17 Oct 2023 13:22:45 -0700 Subject: [PATCH 09/36] fix(BaseLoader): fix missing function decorator --- pyTigerGraph/gds/dataloaders.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index af99d49a..40a70384 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -622,6 +622,7 @@ def _download_graph_kafka( "Unrecognized key {} for messages in kafka".format(key) ) + @staticmethod def _download_unimode_kafka( exit_event: Event, read_task_q: Queue, From a1ff95b8be5acec43501870bff772b4af877f706 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Wed, 18 Oct 2023 15:48:19 -0700 Subject: [PATCH 10/36] fix(BaseLoader,VertexLoader): change how we shuffle --- pyTigerGraph/gds/dataloaders.py | 93 +++++++++++++++++++++++---------- 1 file changed, 66 insertions(+), 27 deletions(-) diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index 40a70384..59b6cd5e 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -53,6 +53,7 @@ def __init__( graph: "TigerGraphConnection", loader_id: str = None, num_batches: int = 1, + shuffle: bool = False, buffer_size: int = 4, output_format: str = "dataframe", reverse_edge: bool = False, @@ -185,7 +186,8 @@ def __init__( # Queues to store tasks and data self._request_task_q = None self._download_task_q = None - self._read_task_q = None + self._read_task_q1 = None + 
self._read_task_q2 = None self._data_q = None self._kafka_topic = None self._all_kafka_topics = set() @@ -237,7 +239,7 @@ def __init__( self._iterator = False self.callback_fn = callback_fn self.distributed_query = distributed_query - self.num_heap_inserts = 10 + self.shuffle = shuffle # Kafka consumer and admin self.max_kafka_msg_size = Kafka_max_msg_size self.kafka_address_consumer = ( @@ -552,7 +554,9 @@ def _request_kafka( def _request_graph_rest( tgraph: "TigerGraphConnection", query_name: str, - read_task_q: Queue, + read_task_q1: Queue, + read_task_q2: Queue, + shuffle: bool = False, timeout: int = 600000, payload: dict = {}, ) -> NoReturn: @@ -560,25 +564,42 @@ def _request_graph_rest( resp = tgraph.runInstalledQuery( query_name, params=payload, timeout=timeout, usePost=True ) - # Put raw data into reading queue + # Put raw data into reading queue. + # If shuffle, randomly choose between the two queues. + # Otherwise, put all into queue 1. for i in resp: - read_task_q.put((i["vertex_batch"], i["edge_batch"])) + if shuffle and random.random>0.5: + read_task_q2.put((i["vertex_batch"], i["edge_batch"])) + else: + read_task_q1.put((i["vertex_batch"], i["edge_batch"])) + read_task_q1.put(None) + read_task_q2.put(None) @staticmethod def _request_unimode_rest( tgraph: "TigerGraphConnection", query_name: str, - read_task_q: Queue, + read_task_q1: Queue, + read_task_q2: Queue, + shuffle: bool = False, timeout: int = 600000, payload: dict = {}, ) -> NoReturn: # Run query + #TODO: check what happens when the query times out resp = tgraph.runInstalledQuery( query_name, params=payload, timeout=timeout, usePost=True ) # Put raw data into reading queue + # If shuffle, randomly choose between the two queues. + # Otherwise, put all into queue 1. for i in resp: - read_task_q.put(i["data_batch"]) + if shuffle and random.random()>0.5: + read_task_q2.put(i["data_batch"]) + else: + read_task_q1.put(i["data_batch"]) + read_task_q1.put(None) + read_task_q2.put(None) @staticmethod def _download_graph_kafka( @@ -822,10 +843,9 @@ def _read_graph_data( @staticmethod def _read_vertex_data( exit_event: Event, - in_q: Queue, + in_q: List[Queue], out_q: Queue, batch_size: int, - shuffle: bool = False, v_in_feats: Union[list, dict] = [], v_out_labels: Union[list, dict] = [], v_extra_feats: Union[list, dict] = [], @@ -835,20 +855,27 @@ def _read_vertex_data( callback_fn: Callable = None ) -> NoReturn: buffer = [] - while not exit_event.is_set(): + curr_q_idx = 0 + curr_q = in_q[curr_q_idx] + is_q_empty = [False]*len(in_q) + last_batch = False + while (not exit_event.is_set()) and (not all(is_q_empty)): try: - raw = in_q.get(timeout=1) + raw = curr_q.get(timeout=0.5) except Empty: + next_q_idx = 1 - curr_q_idx + if not is_q_empty[next_q_idx]: + curr_q_idx = next_q_idx + curr_q = in_q[curr_q_idx] continue - # if shuffle the data, 50% chance to save this data point for later - if shuffle and (random.random() < 0.5): - in_q.task_done() - in_q.put(raw) - continue - # Store raw into buffer until there are enough data points for a batch - buffer.append(raw) - in_q.task_done() - if len(buffer) < batch_size: + if raw is None: + is_q_empty[curr_q_idx] = True + if all(is_q_empty): + if len(buffer) > 0: + last_batch = True + else: + buffer.append(raw) + if (len(buffer) < batch_size) and (not last_batch): continue try: data = BaseLoader._parse_vertex_data( @@ -877,6 +904,7 @@ def _read_vertex_data( logger.error("Parameters:\n v_in_feats={}\n v_out_labels={}\n v_extra_feats={}\n v_attr_types={}\n delimiter={}\n".format( 
v_in_feats, v_out_labels, v_extra_feats, v_attr_types, delimiter)) buffer.clear() + out_q.put(None) @staticmethod def _read_edge_data( @@ -1676,7 +1704,9 @@ def _start_request(self, is_graph: bool): kwargs=dict( tgraph = self._graph, query_name = self.query_name, - read_task_q = self._read_task_q, + read_task_q1 = self._read_task_q1, + read_task_q2 = self._read_task_q2, + shuffle = self.shuffle, timeout = self.timeout, payload = self._payload ), @@ -1767,10 +1797,16 @@ def _reset(self, theend=False) -> None: self._download_task_q.get(block=False) except Empty: break - if self._read_task_q: + if self._read_task_q1: while True: try: - self._read_task_q.get(block=False) + self._read_task_q1.get(block=False) + except Empty: + break + if self._read_task_q2: + while True: + try: + self._read_task_q2.get(block=False) except Empty: break if self._data_q: @@ -1785,14 +1821,15 @@ def _reset(self, theend=False) -> None: self._downloader.join() if self._reader: self._reader.join() - del self._request_task_q, self._download_task_q, self._read_task_q, self._data_q + del self._request_task_q, self._download_task_q, self._read_task_q1, self._read_task_q2, self._data_q self._exit_event = None self._requester, self._downloader, self._reader = None, None, None - self._request_task_q, self._download_task_q, self._read_task_q, self._data_q = ( + self._request_task_q, self._download_task_q, self._read_task_q1, self._read_task_q2, self._data_q = ( None, None, None, None, + None ) if theend: if self._kafka_topic and self._kafka_consumer: @@ -2760,6 +2797,7 @@ def __init__( graph, loader_id, num_batches, + shuffle, buffer_size, output_format, reverse_edge, @@ -2895,7 +2933,8 @@ def _install_query(self, force: bool = False) -> str: def _start(self) -> None: # Create task and result queues - self._read_task_q = Queue(self.buffer_size * 2) + self._read_task_q1 = Queue(self.buffer_size) + self._read_task_q2 = Queue(self.buffer_size) self._data_q = Queue(self.buffer_size) self._exit_event = Event() @@ -2910,7 +2949,7 @@ def _start(self) -> None: target=self._read_vertex_data, kwargs=dict( exit_event = self._exit_event, - in_q = self._read_task_q, + in_q = [self._read_task_q1, self._read_task_q2], out_q = self._data_q, batch_size = self.batch_size, v_in_feats = self.attributes, From e7f3e94016a84b81810a9206b076bad4cd6c68b3 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Wed, 18 Oct 2023 16:09:17 -0700 Subject: [PATCH 11/36] feat(BaseLoader,VertexLoader): fix when to stop kafka --- pyTigerGraph/gds/dataloaders.py | 18 +++++++++--------- .../gds/gsql/dataloaders/vertex_loader.gsql | 7 +++++++ 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index 59b6cd5e..f30f37b1 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -647,21 +647,21 @@ def _download_graph_kafka( def _download_unimode_kafka( exit_event: Event, read_task_q: Queue, - kafka_consumer: "KafkaConsumer", - max_wait_time: int = 300 + kafka_consumer: "KafkaConsumer" ) -> NoReturn: - delivered_batch = 0 - wait_time = 0 - while (not exit_event.is_set()) and (wait_time < max_wait_time): + empty = False + while (not exit_event.is_set()) and (not empty): resp = kafka_consumer.poll(1000) if not resp: - wait_time += 1 continue - wait_time = 0 for msgs in resp.values(): for message in msgs: - read_task_q.put(message.value.decode("utf-8")) - delivered_batch += 1 + key = message.key.decode("utf-8") + if key == "STOP": + read_task_q.put(None) + empty = True + else: + 
read_task_q.put(message.value.decode("utf-8")) @staticmethod def _read_graph_data( diff --git a/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql b/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql index 500a0b36..70604768 100644 --- a/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql +++ b/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql @@ -73,6 +73,13 @@ CREATE QUERY vertex_loader_{QUERYSUFFIX}( END, s.@v_data = "" LIMIT 1; + + FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO + INT kafka_errcode = write_to_kafka(producer, kafka_topic, i, "STOP", ""); + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending STOP signal to topic partition " + stringify(i) + ": " + stringify(kafka_errcode) + "\n"); + END; + END; INT kafka_errcode = close_kafka_producer(producer, kafka_timeout); IF kafka_errcode!=0 THEN From 1dfa8bd3622d1dae12d97ab26edfe11843f8fb1a Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Wed, 18 Oct 2023 16:35:20 -0700 Subject: [PATCH 12/36] feat(VertexLoader): update shuffle --- pyTigerGraph/gds/dataloaders.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index f30f37b1..6691db67 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -646,8 +646,10 @@ def _download_graph_kafka( @staticmethod def _download_unimode_kafka( exit_event: Event, - read_task_q: Queue, - kafka_consumer: "KafkaConsumer" + read_task_q1: Queue, + read_task_q2: Queue, + kafka_consumer: "KafkaConsumer", + shuffle: bool = False ) -> NoReturn: empty = False while (not exit_event.is_set()) and (not empty): @@ -658,10 +660,14 @@ def _download_unimode_kafka( for message in msgs: key = message.key.decode("utf-8") if key == "STOP": - read_task_q.put(None) + read_task_q1.put(None) + read_task_q2.put(None) empty = True else: - read_task_q.put(message.value.decode("utf-8")) + if shuffle and random.random()>0.5: + read_task_q2.put(message.value.decode("utf-8")) + else: + read_task_q1.put(message.value.decode("utf-8")) @staticmethod def _read_graph_data( @@ -1666,9 +1672,10 @@ def _start_request(self, is_graph: bool): target=self._download_unimode_kafka, kwargs=dict( exit_event = self._exit_event, - read_task_q = self._read_task_q, + read_task_q1 = self._read_task_q1, + read_task_q2 = self._read_task_q2, kafka_consumer = self._kafka_consumer, - max_wait_time = self.timeout + shuffle = self.shuffle ), ) self._downloader.start() From d5a97dfed6e5989c057e501e1fba9a8c4b7d5515 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Wed, 18 Oct 2023 16:37:20 -0700 Subject: [PATCH 13/36] test(VertexLoader): update tests --- tests/test_gds_VertexLoader.py | 224 +++++++++++++++++++++++++-------- 1 file changed, 172 insertions(+), 52 deletions(-) diff --git a/tests/test_gds_VertexLoader.py b/tests/test_gds_VertexLoader.py index 3b76890d..e92a0d77 100644 --- a/tests/test_gds_VertexLoader.py +++ b/tests/test_gds_VertexLoader.py @@ -7,7 +7,7 @@ from pyTigerGraph.gds.utilities import is_query_installed -class TestGDSVertexLoader(unittest.TestCase): +class TestGDSVertexLoaderKafka(unittest.TestCase): @classmethod def setUpClass(cls): cls.conn = make_connection(graphname="Cora") @@ -19,12 +19,10 @@ def test_init(self): batch_size=16, shuffle=True, filter_by="train_mask", - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 9) + self.assertIsNone(loader.num_batches) def 
test_iterate(self): loader = VertexLoader( @@ -33,21 +31,25 @@ def test_iterate(self): batch_size=16, shuffle=True, filter_by="train_mask", - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) num_batches = 0 + batch_sizes = [] for data in loader: - # print(num_batches, data.head()) + # print(num_batches, data.shape, data.head()) self.assertIsInstance(data, DataFrame) self.assertIn("x", data.columns) self.assertIn("y", data.columns) self.assertIn("train_mask", data.columns) self.assertIn("val_mask", data.columns) self.assertIn("test_mask", data.columns) + self.assertEqual(data.shape[1], 6) + batch_sizes.append(data.shape[0]) num_batches += 1 self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_all_vertices(self): loader = VertexLoader( @@ -56,8 +58,6 @@ def test_all_vertices(self): num_batches=1, shuffle=False, filter_by="train_mask", - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) data = loader.data @@ -68,6 +68,29 @@ def test_all_vertices(self): self.assertIn("train_mask", data.columns) self.assertIn("val_mask", data.columns) self.assertIn("test_mask", data.columns) + self.assertEqual(data.shape[0], 140) + self.assertEqual(data.shape[1], 6) + + def test_all_vertices_multichar_delimiter(self): + loader = VertexLoader( + graph=self.conn, + attributes=["x", "y", "train_mask", "val_mask", "test_mask"], + num_batches=1, + shuffle=False, + filter_by="train_mask", + delimiter="$|", + kafka_address="kafka:9092", + ) + data = loader.data + # print(data) + self.assertIsInstance(data, DataFrame) + self.assertIn("x", data.columns) + self.assertIn("y", data.columns) + self.assertIn("train_mask", data.columns) + self.assertIn("val_mask", data.columns) + self.assertIn("test_mask", data.columns) + self.assertEqual(data.shape[0], 140) + self.assertEqual(data.shape[1], 6) def test_sasl_plaintext(self): loader = VertexLoader( @@ -126,6 +149,97 @@ def test_sasl_ssl(self): num_batches += 1 self.assertEqual(num_batches, 9) + +class TestGDSHeteroVertexLoaderKafka(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.conn = make_connection(graphname="hetero") + + def test_init(self): + loader = VertexLoader( + graph=self.conn, + attributes={"v0": ["x", "y"], + "v1": ["x"]}, + batch_size=20, + shuffle=True, + kafka_address="kafka:9092" + ) + self.assertTrue(is_query_installed(self.conn, loader.query_name)) + self.assertIsNone(loader.num_batches) + + def test_iterate(self): + loader = VertexLoader( + graph=self.conn, + attributes={"v0": ["x", "y"], + "v1": ["x"]}, + batch_size=20, + shuffle=True, + kafka_address="kafka:9092" + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + batchsize = 0 + if "v0" in data: + self.assertIsInstance(data["v0"], DataFrame) + self.assertIn("x", data["v0"].columns) + self.assertIn("y", data["v0"].columns) + batchsize += data["v0"].shape[0] + self.assertEqual(data["v0"].shape[1], 3) + if "v1" in data: + self.assertIsInstance(data["v1"], DataFrame) + self.assertIn("x", data["v1"].columns) + batchsize += data["v1"].shape[0] + self.assertEqual(data["v1"].shape[1], 2) + self.assertGreater(len(data), 0) + num_batches += 1 + batch_sizes.append(batchsize) + self.assertEqual(num_batches, 10) + for i in batch_sizes[:-1]: + self.assertEqual(i, 20) + self.assertLessEqual(batch_sizes[-1], 20) + + def test_all_vertices(self): + loader = VertexLoader( + graph=self.conn, + attributes={"v0": ["x", "y"], + "v1": ["x"]}, + 
num_batches=1, + shuffle=False, + kafka_address="kafka:9092" + ) + data = loader.data + # print(data) + self.assertIsInstance(data["v0"], DataFrame) + self.assertTupleEqual(data["v0"].shape, (76, 3)) + self.assertIsInstance(data["v1"], DataFrame) + self.assertTupleEqual(data["v1"].shape, (110, 2)) + self.assertIn("x", data["v0"].columns) + self.assertIn("y", data["v0"].columns) + self.assertIn("x", data["v1"].columns) + + def test_all_vertices_multichar_delimiter(self): + loader = VertexLoader( + graph=self.conn, + attributes={"v0": ["x", "y"], + "v1": ["x"]}, + num_batches=1, + shuffle=False, + delimiter="|$", + kafka_address="kafka:9092" + ) + data = loader.data + # print(data) + self.assertIsInstance(data["v0"], DataFrame) + self.assertTupleEqual(data["v0"].shape, (76, 3)) + self.assertIsInstance(data["v1"], DataFrame) + self.assertTupleEqual(data["v1"].shape, (110, 2)) + self.assertIn("x", data["v0"].columns) + self.assertIn("y", data["v0"].columns) + self.assertIn("x", data["v1"].columns) + + class TestGDSVertexLoaderREST(unittest.TestCase): @classmethod def setUpClass(cls): @@ -137,12 +251,10 @@ def test_init(self): attributes=["x", "y", "train_mask", "val_mask", "test_mask"], batch_size=16, shuffle=True, - filter_by="train_mask", - loader_id=None, - buffer_size=4, + filter_by="train_mask" ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 9) + self.assertIsNone(loader.num_batches) def test_iterate(self): loader = VertexLoader( @@ -150,21 +262,25 @@ def test_iterate(self): attributes=["x", "y", "train_mask", "val_mask", "test_mask"], batch_size=16, shuffle=True, - filter_by="train_mask", - loader_id=None, - buffer_size=4, + filter_by="train_mask" ) num_batches = 0 + batch_sizes = [] for data in loader: - # print(num_batches, data.head()) + # print(num_batches, data.shape, data.head()) self.assertIsInstance(data, DataFrame) self.assertIn("x", data.columns) self.assertIn("y", data.columns) self.assertIn("train_mask", data.columns) self.assertIn("val_mask", data.columns) self.assertIn("test_mask", data.columns) + self.assertEqual(data.shape[1], 6) + batch_sizes.append(data.shape[0]) num_batches += 1 self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_all_vertices(self): loader = VertexLoader( @@ -172,9 +288,7 @@ def test_all_vertices(self): attributes=["x", "y", "train_mask", "val_mask", "test_mask"], num_batches=1, shuffle=False, - filter_by="train_mask", - loader_id=None, - buffer_size=4, + filter_by="train_mask" ) data = loader.data # print(data) @@ -184,6 +298,8 @@ def test_all_vertices(self): self.assertIn("train_mask", data.columns) self.assertIn("val_mask", data.columns) self.assertIn("test_mask", data.columns) + self.assertEqual(data.shape[0], 140) + self.assertEqual(data.shape[1], 6) def test_all_vertices_multichar_delimiter(self): loader = VertexLoader( @@ -192,8 +308,6 @@ def test_all_vertices_multichar_delimiter(self): num_batches=1, shuffle=False, filter_by="train_mask", - loader_id=None, - buffer_size=4, delimiter="$|" ) data = loader.data @@ -204,6 +318,8 @@ def test_all_vertices_multichar_delimiter(self): self.assertIn("train_mask", data.columns) self.assertIn("val_mask", data.columns) self.assertIn("test_mask", data.columns) + self.assertEqual(data.shape[0], 140) + self.assertEqual(data.shape[1], 6) def test_string_attr(self): conn = make_connection(graphname="Social") @@ -212,14 +328,13 @@ def test_string_attr(self): 
graph=conn, attributes=["age", "state"], num_batches=1, - shuffle=False, - loader_id=None, - buffer_size=4, + shuffle=False ) data = loader.data # print(data) self.assertIsInstance(data, DataFrame) self.assertEqual(data.shape[0], 7) + self.assertEqual(data.shape[1], 3) self.assertIn("age", data.columns) self.assertIn("state", data.columns) @@ -235,13 +350,10 @@ def test_init(self): attributes={"v0": ["x", "y"], "v1": ["x"]}, batch_size=20, - shuffle=True, - filter_by=None, - loader_id=None, - buffer_size=4, + shuffle=True ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 10) + self.assertIsNone(loader.num_batches) def test_iterate(self): loader = VertexLoader( @@ -249,21 +361,31 @@ def test_iterate(self): attributes={"v0": ["x", "y"], "v1": ["x"]}, batch_size=20, - shuffle=True, - filter_by=None, - loader_id=None, - buffer_size=4, + shuffle=True ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) - self.assertIsInstance(data["v0"], DataFrame) - self.assertIsInstance(data["v1"], DataFrame) - self.assertIn("x", data["v0"].columns) - self.assertIn("y", data["v0"].columns) - self.assertIn("x", data["v1"].columns) + batchsize = 0 + if "v0" in data: + self.assertIsInstance(data["v0"], DataFrame) + self.assertIn("x", data["v0"].columns) + self.assertIn("y", data["v0"].columns) + batchsize += data["v0"].shape[0] + self.assertEqual(data["v0"].shape[1], 3) + if "v1" in data: + self.assertIsInstance(data["v1"], DataFrame) + self.assertIn("x", data["v1"].columns) + batchsize += data["v1"].shape[0] + self.assertEqual(data["v1"].shape[1], 2) + self.assertGreater(len(data), 0) num_batches += 1 + batch_sizes.append(batchsize) self.assertEqual(num_batches, 10) + for i in batch_sizes[:-1]: + self.assertEqual(i, 20) + self.assertLessEqual(batch_sizes[-1], 20) def test_all_vertices(self): loader = VertexLoader( @@ -271,10 +393,7 @@ def test_all_vertices(self): attributes={"v0": ["x", "y"], "v1": ["x"]}, num_batches=1, - shuffle=False, - filter_by=None, - loader_id=None, - buffer_size=4, + shuffle=False ) data = loader.data # print(data) @@ -293,9 +412,6 @@ def test_all_vertices_multichar_delimiter(self): "v1": ["x"]}, num_batches=1, shuffle=False, - filter_by=None, - loader_id=None, - buffer_size=4, delimiter="|$" ) data = loader.data @@ -311,11 +427,6 @@ def test_all_vertices_multichar_delimiter(self): if __name__ == "__main__": suite = unittest.TestSuite() - suite.addTest(TestGDSVertexLoader("test_init")) - suite.addTest(TestGDSVertexLoader("test_iterate")) - suite.addTest(TestGDSVertexLoader("test_all_vertices")) - # suite.addTest(TestGDSVertexLoader("test_sasl_plaintext")) - # suite.addTest(TestGDSVertexLoader("test_sasl_ssl")) suite.addTest(TestGDSVertexLoaderREST("test_init")) suite.addTest(TestGDSVertexLoaderREST("test_iterate")) suite.addTest(TestGDSVertexLoaderREST("test_all_vertices")) @@ -325,6 +436,15 @@ def test_all_vertices_multichar_delimiter(self): suite.addTest(TestGDSHeteroVertexLoaderREST("test_iterate")) suite.addTest(TestGDSHeteroVertexLoaderREST("test_all_vertices")) suite.addTest(TestGDSHeteroVertexLoaderREST("test_all_vertices_multichar_delimiter")) - + suite.addTest(TestGDSVertexLoaderKafka("test_init")) + suite.addTest(TestGDSVertexLoaderKafka("test_iterate")) + suite.addTest(TestGDSVertexLoaderKafka("test_all_vertices")) + suite.addTest(TestGDSVertexLoaderKafka("test_all_vertices_multichar_delimiter")) + # suite.addTest(TestGDSVertexLoaderKafka("test_sasl_plaintext")) + # 
suite.addTest(TestGDSVertexLoaderKafka("test_sasl_ssl")) + suite.addTest(TestGDSHeteroVertexLoaderKafka("test_init")) + suite.addTest(TestGDSHeteroVertexLoaderKafka("test_iterate")) + suite.addTest(TestGDSHeteroVertexLoaderKafka("test_all_vertices")) + suite.addTest(TestGDSHeteroVertexLoaderKafka("test_all_vertices_multichar_delimiter")) runner = unittest.TextTestRunner(verbosity=2, failfast=True) runner.run(suite) From 0298bc02308cc776968efa5d6da814c1084f2d4c Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Thu, 19 Oct 2023 10:57:42 -0700 Subject: [PATCH 14/36] feat(VertexLoader): move kafka commands into template --- pyTigerGraph/gds/dataloaders.py | 52 ++++++++++++++----- .../gds/gsql/dataloaders/vertex_loader.gsql | 30 +++++------ 2 files changed, 50 insertions(+), 32 deletions(-) diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index 6691db67..389f773d 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -2899,17 +2899,33 @@ def _install_query(self, force: bool = False) -> str: v_attr_types = self._v_schema[vtype] if v_attr_names: print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query_http += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype, print_attr) - print_query_kafka += '{} s.type == "{}" THEN \n s.@v_data += (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype, print_attr) + print_query_http += """ + {} s.type == "{}" THEN + @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + "\\n")\n"""\ + .format("IF" if idx==0 else "ELSE IF", vtype, print_attr) + print_query_kafka += """ + {} s.type == "{}" THEN + STRING msg = (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "vertex_" + stringify(getvid(s)), msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending data for vertex " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\\n") + END\n""".format("IF" if idx==0 else "ELSE IF", vtype, print_attr) else: - print_query_http += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype) - print_query_kafka += '{} s.type == "{}" THEN \n s.@v_data += (s.type + delimiter + stringify(getvid(s)) + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype) - print_query_http += "END" - print_query_kafka += "END" + print_query_http += """ + {} s.type == "{}" THEN + @@v_batch += (s.type + delimiter + stringify(getvid(s)) + "\\n")\n"""\ + .format("IF" if idx==0 else "ELSE IF", vtype) + print_query_kafka += """ + {} s.type == "{}" THEN + STRING msg = (s.type + delimiter + stringify(getvid(s)) + "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "vertex_" + stringify(getvid(s)), msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending data for vertex " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\\n") + END\n""".format("IF" if idx==0 else "ELSE IF", vtype) + print_query_http += "\ + END" + print_query_kafka += "\ + END" query_replace["{VERTEXATTRSHTTP}"] = print_query_http query_replace["{VERTEXATTRSKAFKA}"] = print_query_kafka else: @@ -2921,12 +2937,20 @@ def _install_query(self, force: bool = False) -> str: 
print_query_http = '@@v_batch += (stringify(getvid(s)) + delimiter + {} + "\\n")'.format( print_attr ) - print_query_kafka = 's.@v_data += (stringify(getvid(s)) + delimiter + {} + "\\n")'.format( - print_attr - ) + print_query_kafka = """ + STRING msg = (stringify(getvid(s)) + delimiter + {} + "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "vertex_" + stringify(getvid(s)), msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending data for vertex " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\\n") + END""".format(print_attr) else: print_query_http = '@@v_batch += (stringify(getvid(s)) + "\\n")' - print_query_kafka = 's.@v_data += (stringify(getvid(s)) + "\\n")' + print_query_kafka = """ + STRING msg = (stringify(getvid(s)) + "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "vertex_" + stringify(getvid(s)), msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending data for vertex " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\\n") + END""" query_replace["{VERTEXATTRSHTTP}"] = print_query_http query_replace["{VERTEXATTRSKAFKA}"] = print_query_kafka # Install query diff --git a/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql b/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql index 70604768..c3fbc930 100644 --- a/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql +++ b/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql @@ -4,7 +4,7 @@ CREATE QUERY vertex_loader_{QUERYSUFFIX}( SET v_types, STRING delimiter, STRING kafka_address="", - STRING kafka_topic, + STRING kafka_topic="", INT kafka_topic_partitions=1, STRING kafka_max_size="104857600", INT kafka_timeout=300000, @@ -43,11 +43,7 @@ CREATE QUERY vertex_loader_{QUERYSUFFIX}( */ # If getting all vertices of given types IF input_vertices.size()==0 THEN - # Filter seeds if needed. 
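    # With this change the per-vertex write_to_kafka call is emitted by _install_query as
    # part of the {VERTEXATTRSKAFKA} snippet, so the POST-ACCUM below no longer builds up
    # s.@v_data before writing, and the filter_by check moves into the WHERE clause of a
    # single SELECT over the seed vertices.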
start = {v_types}; - seeds = SELECT s - FROM start:s - WHERE filter_by is NULL OR s.getAttr(filter_by, "BOOL"); # Generate batches # If using kafka to export @@ -64,14 +60,10 @@ CREATE QUERY vertex_loader_{QUERYSUFFIX}( sasl_kerberos_keytab, sasl_kerberos_principal); res = SELECT s - FROM seeds:s + FROM start:s + WHERE filter_by is NULL OR s.getAttr(filter_by, "BOOL") POST-ACCUM - {VERTEXATTRSKAFKA}, - INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "vertex_" + stringify(getvid(s)), s.@v_data), - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending data for vertex " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\n") - END, - s.@v_data = "" + {VERTEXATTRSKAFKA} LIMIT 1; FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO @@ -90,8 +82,9 @@ CREATE QUERY vertex_loader_{QUERYSUFFIX}( ELSE ListAccum @@v_batch; res = SELECT s - FROM seeds:s - ACCUM + FROM start:s + WHERE filter_by is NULL OR s.getAttr(filter_by, "BOOL") + POST-ACCUM {VERTEXATTRSHTTP}; FOREACH i IN @@v_batch DO @@ -102,10 +95,11 @@ CREATE QUERY vertex_loader_{QUERYSUFFIX}( ELSE ListAccum @@v_batch; start = input_vertices; - seeds = SELECT s - FROM start:s - POST-ACCUM - {VERTEXATTRSHTTP}; + res = SELECT s + FROM start:s + POST-ACCUM + {VERTEXATTRSHTTP} + LIMIT 1; FOREACH i IN @@v_batch DO PRINT i as data_batch; END; From 7315071e8d3214020809a837aa0db332f7737cda Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Thu, 19 Oct 2023 20:37:40 -0700 Subject: [PATCH 15/36] feat: move shuffle to gsql for speed and add EdgeLoader --- pyTigerGraph/gds/dataloaders.py | 234 +++++++------- .../gds/gsql/dataloaders/edge_loader.gsql | 239 ++++---------- .../gds/gsql/dataloaders/vertex_loader.gsql | 61 ++-- tests/test_gds_EdgeLoader.py | 302 ++++++++++++++---- 4 files changed, 441 insertions(+), 395 deletions(-) diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index 389f773d..d43027de 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -53,7 +53,6 @@ def __init__( graph: "TigerGraphConnection", loader_id: str = None, num_batches: int = 1, - shuffle: bool = False, buffer_size: int = 4, output_format: str = "dataframe", reverse_edge: bool = False, @@ -186,8 +185,7 @@ def __init__( # Queues to store tasks and data self._request_task_q = None self._download_task_q = None - self._read_task_q1 = None - self._read_task_q2 = None + self._read_task_q = None self._data_q = None self._kafka_topic = None self._all_kafka_topics = set() @@ -239,7 +237,6 @@ def __init__( self._iterator = False self.callback_fn = callback_fn self.distributed_query = distributed_query - self.shuffle = shuffle # Kafka consumer and admin self.max_kafka_msg_size = Kafka_max_msg_size self.kafka_address_consumer = ( @@ -554,9 +551,7 @@ def _request_kafka( def _request_graph_rest( tgraph: "TigerGraphConnection", query_name: str, - read_task_q1: Queue, - read_task_q2: Queue, - shuffle: bool = False, + read_task_q: Queue, timeout: int = 600000, payload: dict = {}, ) -> NoReturn: @@ -565,23 +560,15 @@ def _request_graph_rest( query_name, params=payload, timeout=timeout, usePost=True ) # Put raw data into reading queue. - # If shuffle, randomly choose between the two queues. - # Otherwise, put all into queue 1. 
for i in resp: - if shuffle and random.random>0.5: - read_task_q2.put((i["vertex_batch"], i["edge_batch"])) - else: - read_task_q1.put((i["vertex_batch"], i["edge_batch"])) - read_task_q1.put(None) - read_task_q2.put(None) + read_task_q.put((i["vertex_batch"], i["edge_batch"])) + read_task_q.put(None) @staticmethod def _request_unimode_rest( tgraph: "TigerGraphConnection", query_name: str, - read_task_q1: Queue, - read_task_q2: Queue, - shuffle: bool = False, + read_task_q: Queue, timeout: int = 600000, payload: dict = {}, ) -> NoReturn: @@ -591,15 +578,9 @@ def _request_unimode_rest( query_name, params=payload, timeout=timeout, usePost=True ) # Put raw data into reading queue - # If shuffle, randomly choose between the two queues. - # Otherwise, put all into queue 1. for i in resp: - if shuffle and random.random()>0.5: - read_task_q2.put(i["data_batch"]) - else: - read_task_q1.put(i["data_batch"]) - read_task_q1.put(None) - read_task_q2.put(None) + read_task_q.put(i["data_batch"]) + read_task_q.put(None) @staticmethod def _download_graph_kafka( @@ -646,10 +627,8 @@ def _download_graph_kafka( @staticmethod def _download_unimode_kafka( exit_event: Event, - read_task_q1: Queue, - read_task_q2: Queue, - kafka_consumer: "KafkaConsumer", - shuffle: bool = False + read_task_q: Queue, + kafka_consumer: "KafkaConsumer" ) -> NoReturn: empty = False while (not exit_event.is_set()) and (not empty): @@ -660,14 +639,10 @@ def _download_unimode_kafka( for message in msgs: key = message.key.decode("utf-8") if key == "STOP": - read_task_q1.put(None) - read_task_q2.put(None) + read_task_q.put(None) empty = True else: - if shuffle and random.random()>0.5: - read_task_q2.put(message.value.decode("utf-8")) - else: - read_task_q1.put(message.value.decode("utf-8")) + read_task_q.put(message.value.decode("utf-8")) @staticmethod def _read_graph_data( @@ -849,7 +824,7 @@ def _read_graph_data( @staticmethod def _read_vertex_data( exit_event: Event, - in_q: List[Queue], + in_q: Queue, out_q: Queue, batch_size: int, v_in_feats: Union[list, dict] = [], @@ -861,24 +836,17 @@ def _read_vertex_data( callback_fn: Callable = None ) -> NoReturn: buffer = [] - curr_q_idx = 0 - curr_q = in_q[curr_q_idx] - is_q_empty = [False]*len(in_q) last_batch = False - while (not exit_event.is_set()) and (not all(is_q_empty)): + is_empty = False + while (not exit_event.is_set()) and (not is_empty): try: - raw = curr_q.get(timeout=0.5) + raw = in_q.get(timeout=1) except Empty: - next_q_idx = 1 - curr_q_idx - if not is_q_empty[next_q_idx]: - curr_q_idx = next_q_idx - curr_q = in_q[curr_q_idx] continue if raw is None: - is_q_empty[curr_q_idx] = True - if all(is_q_empty): - if len(buffer) > 0: - last_batch = True + is_empty = True + if len(buffer) > 0: + last_batch = True else: buffer.append(raw) if (len(buffer) < batch_size) and (not last_batch): @@ -918,7 +886,6 @@ def _read_edge_data( in_q: Queue, out_q: Queue, batch_size: int, - shuffle: bool = False, e_in_feats: Union[list, dict] = [], e_out_labels: Union[list, dict] = [], e_extra_feats: Union[list, dict] = [], @@ -928,20 +895,20 @@ def _read_edge_data( callback_fn: Callable = None ) -> NoReturn: buffer = [] - while not exit_event.is_set(): + is_empty = False + last_batch = False + while not exit_event.is_set() and (not is_empty): try: raw = in_q.get(timeout=1) except Empty: continue - # if shuffle the data, 50% chance to save this data point for later - if shuffle and (random.random() < 0.5): - in_q.task_done() - in_q.put(raw) - continue - # Store raw into buffer until there are enough 
data points for a batch - buffer.append(raw) - in_q.task_done() - if len(buffer) < batch_size: + if raw is None: + is_empty = True + if len(buffer) > 0: + last_batch = True + else: + buffer.append(raw) + if (len(buffer) < batch_size) and (not last_batch): continue try: data = BaseLoader._parse_edge_data( @@ -970,6 +937,7 @@ def _read_edge_data( logger.error("Parameters:\n e_in_feats={}\n e_out_labels={}\n e_extra_feats={}\n e_attr_types={}\n delimiter={}\n".format( e_in_feats, e_out_labels, e_extra_feats, e_attr_types, delimiter)) buffer.clear() + out_q.put(None) @staticmethod def _parse_vertex_data( @@ -1672,10 +1640,8 @@ def _start_request(self, is_graph: bool): target=self._download_unimode_kafka, kwargs=dict( exit_event = self._exit_event, - read_task_q1 = self._read_task_q1, - read_task_q2 = self._read_task_q2, - kafka_consumer = self._kafka_consumer, - shuffle = self.shuffle + read_task_q = self._read_task_q, + kafka_consumer = self._kafka_consumer ), ) self._downloader.start() @@ -1711,9 +1677,7 @@ def _start_request(self, is_graph: bool): kwargs=dict( tgraph = self._graph, query_name = self.query_name, - read_task_q1 = self._read_task_q1, - read_task_q2 = self._read_task_q2, - shuffle = self.shuffle, + read_task_q = self._read_task_q, timeout = self.timeout, payload = self._payload ), @@ -1804,16 +1768,10 @@ def _reset(self, theend=False) -> None: self._download_task_q.get(block=False) except Empty: break - if self._read_task_q1: - while True: - try: - self._read_task_q1.get(block=False) - except Empty: - break - if self._read_task_q2: + if self._read_task_q: while True: try: - self._read_task_q2.get(block=False) + self._read_task_q.get(block=False) except Empty: break if self._data_q: @@ -1828,11 +1786,10 @@ def _reset(self, theend=False) -> None: self._downloader.join() if self._reader: self._reader.join() - del self._request_task_q, self._download_task_q, self._read_task_q1, self._read_task_q2, self._data_q + del self._request_task_q, self._download_task_q, self._read_task_q, self._data_q self._exit_event = None self._requester, self._downloader, self._reader = None, None, None - self._request_task_q, self._download_task_q, self._read_task_q1, self._read_task_q2, self._data_q = ( - None, + self._request_task_q, self._download_task_q, self._read_task_q, self._data_q = ( None, None, None, @@ -2556,19 +2513,18 @@ def __init__( self._etypes = sorted(self._etypes) # Initialize parameters for the query if batch_size: - # If batch_size is given, calculate the number of batches + # batch size takes precedence over number of batches + self.batch_size = batch_size + self.num_batches = None + else: + # If number of batches is given, calculate batch size if filter_by: num_edges = sum(self._graph.getEdgeStats(e_type)[e_type][filter_by if isinstance(filter_by, str) else filter_by[e_type]]["TRUE"] for e_type in self._etypes) else: num_edges = sum(self._graph.getEdgeCount(i) for i in self._etypes) - self.num_batches = math.ceil(num_edges / batch_size) - else: - # Otherwise, take the number of batches as is. 
+ self.batch_size = math.ceil(num_edges / num_batches) self.num_batches = num_batches # Initialize the exporter - if batch_size: - self._payload["batch_size"] = batch_size - self._payload["num_batches"] = self.num_batches if filter_by: self._payload["filter_by"] = filter_by self._payload["shuffle"] = shuffle @@ -2590,31 +2546,65 @@ def _install_query(self, force: bool = False): if isinstance(self.attributes, dict): # Multiple edge types - print_query = "" + print_query_kafka = "" + print_query_http = "" for idx, etype in enumerate(self._etypes): e_attr_names = self.attributes.get(etype, []) e_attr_types = self._e_schema[etype] if e_attr_names: print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) - print_query += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", etype, print_attr) + print_query_http += """ + {} e.type == "{}" THEN + @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n")\n"""\ + .format("IF" if idx==0 else "ELSE IF", etype, print_attr) + print_query_kafka += """ + {} e.type == "{}" THEN + STRING msg = (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, (getvid(s)+getvid(t))%kafka_topic_partitions, "edge_" + stringify(getvid(s)) + "_" + stringify(getvid(t)), msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending data for edge " + stringify(getvid(s)) + "_" + stringify(getvid(t)) + ": "+ stringify(kafka_errcode) + "\\n") + END\n""".format("IF" if idx==0 else "ELSE IF", etype, print_attr) else: - print_query += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", etype) - print_query += "END" - query_replace["{EDGEATTRS}"] = print_query + print_query_http += """ + {} e.type == "{}" THEN + @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n")\n"""\ + .format("IF" if idx==0 else "ELSE IF", etype) + print_query_kafka += """ + {} e.type == "{}" THEN + STRING msg = (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, (getvid(s)+getvid(t))%kafka_topic_partitions, "edge_" + stringify(getvid(s)) + "_" + stringify(getvid(t)), msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending data for edge " + stringify(getvid(s)) + "_" + stringify(getvid(t)) + ": "+ stringify(kafka_errcode) + "\\n") + END\n""".format("IF" if idx==0 else "ELSE IF", etype) + print_query_http += "\ + END" + print_query_kafka += "\ + END" else: # Ignore edge types e_attr_names = self.attributes e_attr_types = next(iter(self._e_schema.values())) if e_attr_names: print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n")'.format( + print_query_http = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n")'.format( print_attr ) + print_query_kafka = """ + STRING msg = (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, 
(getvid(s)+getvid(t))%kafka_topic_partitions, "edge_" + stringify(getvid(s)) + "_" + stringify(getvid(t)), msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending data for edge " + stringify(getvid(s)) + "_" + stringify(getvid(t)) + ": "+ stringify(kafka_errcode) + "\\n") + END""".format(print_attr) else: - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n")' - query_replace["{EDGEATTRS}"] = print_query + print_query_http = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n")' + print_query_kafka = """ + STRING msg = (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, (getvid(s)+getvid(t))%kafka_topic_partitions, "edge_" + stringify(getvid(s)) + "_" + stringify(getvid(t)), msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending data for edge " + stringify(getvid(s)) + "_" + stringify(getvid(t)) + ": "+ stringify(kafka_errcode) + "\\n") + END""" + query_replace["{EDGEATTRSHTTP}"] = print_query_http + query_replace["{EDGEATTRSKAFKA}"] = print_query_kafka # Install query query_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), @@ -2626,11 +2616,11 @@ def _install_query(self, force: bool = False): def _start(self) -> None: # Create task and result queues - self._read_task_q = Queue(self.buffer_size * 2) + self._read_task_q = Queue(self.buffer_size) self._data_q = Queue(self.buffer_size) self._exit_event = Event() - self._start_request(False, "edge") + self._start_request(False) # Start reading thread. if not self.is_hetero: @@ -2638,27 +2628,20 @@ def _start(self) -> None: else: e_attr_types = self._e_schema self._reader = Thread( - target=self._read_data, - args=( - self._exit_event, - self._read_task_q, - self._data_q, - "edge", - self.output_format, - [], - [], - [], - {}, - self.attributes, - {} if self.is_hetero else [], - {} if self.is_hetero else [], - e_attr_types, - False, - self.delimiter, - False, - self.is_hetero, - self.callback_fn - ), + target=self._read_edge_data, + kwargs=dict( + exit_event = self._exit_event, + in_q = self._read_task_q, + out_q = self._data_q, + batch_size = self.batch_size, + e_in_feats = self.attributes, + e_out_labels = {} if self.is_hetero else [], + e_extra_feats = {} if self.is_hetero else [], + e_attr_types = e_attr_types, + delimiter = self.delimiter, + is_hetero = self.is_hetero, + callback_fn = self.callback_fn + ) ) self._reader.start() @@ -2804,7 +2787,6 @@ def __init__( graph, loader_id, num_batches, - shuffle, buffer_size, output_format, reverse_edge, @@ -2874,6 +2856,7 @@ def __init__( self.num_batches = num_batches if filter_by: self._payload["filter_by"] = filter_by + self._payload["shuffle"] = shuffle self._payload["delimiter"] = delimiter self._payload["v_types"] = self._vtypes self._payload["input_vertices"] = [] @@ -2926,8 +2909,6 @@ def _install_query(self, force: bool = False) -> str: END" print_query_kafka += "\ END" - query_replace["{VERTEXATTRSHTTP}"] = print_query_http - query_replace["{VERTEXATTRSKAFKA}"] = print_query_kafka else: # Ignore vertex types v_attr_names = self.attributes @@ -2951,8 +2932,8 @@ def _install_query(self, force: bool = False) -> str: IF kafka_errcode!=0 THEN @@kafka_error += ("Error sending data for vertex " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\\n") END""" - query_replace["{VERTEXATTRSHTTP}"] = print_query_http - query_replace["{VERTEXATTRSKAFKA}"] = print_query_kafka + query_replace["{VERTEXATTRSHTTP}"] = 
print_query_http + query_replace["{VERTEXATTRSKAFKA}"] = print_query_kafka # Install query query_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), @@ -2964,8 +2945,7 @@ def _install_query(self, force: bool = False) -> str: def _start(self) -> None: # Create task and result queues - self._read_task_q1 = Queue(self.buffer_size) - self._read_task_q2 = Queue(self.buffer_size) + self._read_task_q = Queue(self.buffer_size) self._data_q = Queue(self.buffer_size) self._exit_event = Event() @@ -2980,7 +2960,7 @@ def _start(self) -> None: target=self._read_vertex_data, kwargs=dict( exit_event = self._exit_event, - in_q = [self._read_task_q1, self._read_task_q2], + in_q = self._read_task_q, out_q = self._data_q, batch_size = self.batch_size, v_in_feats = self.attributes, diff --git a/pyTigerGraph/gds/gsql/dataloaders/edge_loader.gsql b/pyTigerGraph/gds/gsql/dataloaders/edge_loader.gsql index fc0085f8..6f0ce37e 100644 --- a/pyTigerGraph/gds/gsql/dataloaders/edge_loader.gsql +++ b/pyTigerGraph/gds/gsql/dataloaders/edge_loader.gsql @@ -1,12 +1,11 @@ CREATE QUERY edge_loader_{QUERYSUFFIX}( - INT batch_size, - INT num_batches=1, - BOOL shuffle=FALSE, STRING filter_by, SET e_types, STRING delimiter, + BOOL shuffle=FALSE, + INT num_chunks=2, STRING kafka_address="", - STRING kafka_topic, + STRING kafka_topic="", INT kafka_topic_partitions=1, STRING kafka_max_size="104857600", INT kafka_timeout=300000, @@ -41,198 +40,72 @@ CREATE QUERY edge_loader_{QUERYSUFFIX}( sasl_password : SASL password for Kafka. ssl_ca_location: Path to CA certificate for verifying the Kafka broker key */ - TYPEDEF TUPLE ID_Tuple; - INT num_vertices; - INT kafka_errcode; SumAccum @tmp_id; - SumAccum @@kafka_error; - UINT producer; - MapAccum @@edges_sampled; - SetAccum @valid_v_out; - SetAccum @valid_v_in; - - # Initialize Kafka producer - IF kafka_address != "" THEN - producer = init_kafka_producer( - kafka_address, kafka_max_size, security_protocol, - sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, - ssl_certificate_location, ssl_key_location, ssl_key_password, - ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, - sasl_kerberos_keytab, sasl_kerberos_principal); - END; - # Shuffle vertex ID if needed start = {ANY}; + # Filter seeds if needed + seeds = SELECT s + FROM start:s -(e_types:e)- :t + WHERE filter_by is NULL OR e.getAttr(filter_by, "BOOL") + POST-ACCUM s.@tmp_id = getvid(s) + POST-ACCUM t.@tmp_id = getvid(t); + # Shuffle vertex ID if needed IF shuffle THEN - num_vertices = start.size(); + INT num_vertices = seeds.size(); res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = floor(rand()*num_vertices); - ELSE - res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = getvid(s); - END; - - SumAccum @@num_edges; - IF filter_by IS NOT NULL THEN - res = SELECT s - FROM start:s -(e_types:e)- :t WHERE e.getAttr(filter_by, "BOOL") - ACCUM - IF e.isDirected() THEN # we divide by two later to correct for undirected edges being counted twice, need to count directed edges twice to get correct count - @@num_edges += 2 - ELSE - @@num_edges += 1 - END; - ELSE - res = SELECT s - FROM start:s -(e_types:e)- :t - ACCUM - IF e.isDirected() THEN # we divide by two later to correct for undirected edges being counted twice, need to count directed edges twice to get correct count - @@num_edges += 2 - ELSE - @@num_edges += 1 - END; - END; - INT batch_s; - IF batch_size IS NULL THEN - batch_s = ceil((@@num_edges/2)/num_batches); - ELSE - batch_s = batch_size; + FROM seeds:s + POST-ACCUM s.@tmp_id = 
floor(rand()*num_vertices) + LIMIT 1; END; # Generate batches - FOREACH batch_id IN RANGE[0, num_batches-1] DO - SumAccum @@e_batch; - SetAccum @@seeds; - SetAccum @@targets; - HeapAccum (1, tmp_id ASC) @@batch_heap; - @@batch_heap.resize(batch_s); - start = {ANY}; - IF filter_by IS NOT NULL THEN - res = - SELECT s - FROM start:s -(e_types:e)- :t - WHERE e.getAttr(filter_by, "BOOL") - AND - ((e.isDirected() AND ((t.@tmp_id >= s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id)))) - OR - (NOT e.isDirected() AND ((t.@tmp_id >= s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id))) - AND ((s.@tmp_id >= t.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id+s.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id))))) - ACCUM - IF t.@tmp_id >= s.@tmp_id THEN - @@batch_heap += ID_Tuple(((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id), s, t) - ELSE - @@batch_heap += ID_Tuple(((s.@tmp_id*s.@tmp_id)+t.@tmp_id), s, t) - END; - - FOREACH elem IN @@batch_heap DO - SetAccum @@src; - @@seeds += elem.src; - @@targets += elem.tgt; - @@src += elem.src; - src = {@@src}; - res = SELECT s FROM src:s -(e_types:e)- :t - WHERE t == elem.tgt - ACCUM - s.@valid_v_out += elem.tgt, - t.@valid_v_in += elem.src; - END; - start = {@@seeds}; - res = - SELECT s - FROM start:s -(e_types:e)- :t - WHERE t in @@targets AND s IN t.@valid_v_in AND t IN s.@valid_v_out - ACCUM - {EDGEATTRS}, - IF t.@tmp_id >= s.@tmp_id THEN - @@edges_sampled += (((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id) -> TRUE), - IF NOT e.isDirected() THEN - @@edges_sampled += (((s.@tmp_id*s.@tmp_id)+t.@tmp_id+s.@tmp_id) -> TRUE) - END - ELSE - @@edges_sampled += (((s.@tmp_id*s.@tmp_id)+t.@tmp_id) -> TRUE), - IF NOT e.isDirected() THEN - @@edges_sampled += (((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id) -> TRUE) - END - END - POST-ACCUM - s.@valid_v_in.clear(), s.@valid_v_out.clear() - POST-ACCUM - t.@valid_v_in.clear(), t.@valid_v_out.clear(); - ELSE - res = - SELECT s - FROM start:s -(e_types:e)- :t - WHERE ((e.isDirected() AND ((t.@tmp_id >= s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id)))) - OR - (NOT e.isDirected() AND ((t.@tmp_id >= s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id))) - AND ((s.@tmp_id >= t.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id+s.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id))))) - ACCUM - IF t.@tmp_id >= s.@tmp_id THEN - @@batch_heap += ID_Tuple(((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id), s, t) - ELSE - @@batch_heap += ID_Tuple(((s.@tmp_id*s.@tmp_id)+t.@tmp_id), s, t) - END; - - FOREACH elem IN @@batch_heap DO - SetAccum @@src; - @@seeds += elem.src; - @@targets += elem.tgt; - @@src += elem.src; - src = {@@src}; - res = SELECT s FROM src:s -(e_types:e)- :t - WHERE t == elem.tgt - ACCUM - s.@valid_v_out += elem.tgt, - t.@valid_v_in += elem.src; - END; - start = {@@seeds}; - res = - SELECT s - FROM start:s -(e_types:e)- :t - WHERE t in 
@@targets AND s IN t.@valid_v_in AND t IN s.@valid_v_out - ACCUM - {EDGEATTRS}, - IF t.@tmp_id >= s.@tmp_id THEN - @@edges_sampled += (((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id) -> TRUE), - IF NOT e.isDirected() THEN - @@edges_sampled += (((s.@tmp_id*s.@tmp_id)+t.@tmp_id+s.@tmp_id) -> TRUE) - END - ELSE - @@edges_sampled += (((s.@tmp_id*s.@tmp_id)+t.@tmp_id) -> TRUE), - IF NOT e.isDirected() THEN - @@edges_sampled += (((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id) -> TRUE) - END - END - POST-ACCUM - s.@valid_v_in.clear(), s.@valid_v_out.clear() - POST-ACCUM - t.@valid_v_in.clear(), t.@valid_v_out.clear(); + # If using kafka to export + IF kafka_address != "" THEN + SumAccum @@kafka_error; + + # Initialize Kafka producer + UINT producer = init_kafka_producer( + kafka_address, kafka_max_size, security_protocol, + sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, + ssl_certificate_location, ssl_key_location, ssl_key_password, + ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, + sasl_kerberos_keytab, sasl_kerberos_principal); + + FOREACH chunk IN RANGE[0, num_chunks-1] DO + res = SELECT s + FROM seeds:s -(e_types:e)- :t + WHERE (filter_by is NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk) + ACCUM + {EDGEATTRSKAFKA} + LIMIT 1; END; - # Export batch - IF kafka_address != "" THEN - # Write to kafka - kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "edge_batch_" + stringify(batch_id), @@e_batch); - IF kafka_errcode != 0 THEN - @@kafka_error += ("Error sending edge batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n"); + + FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO + INT kafka_errcode = write_to_kafka(producer, kafka_topic, i, "STOP", ""); + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending STOP signal to topic partition " + stringify(i) + ": " + stringify(kafka_errcode) + "\n"); END; - ELSE - # Add to response - PRINT @@e_batch AS edge_batch; END; - END; - IF kafka_address != "" THEN - kafka_errcode = close_kafka_producer(producer, kafka_timeout); - IF kafka_errcode != 0 THEN + + INT kafka_errcode = close_kafka_producer(producer, kafka_timeout); + IF kafka_errcode!=0 THEN @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); END; PRINT @@kafka_error as kafkaError; + # Else return as http response + ELSE + FOREACH chunk IN RANGE[0, num_chunks-1] DO + ListAccum @@e_batch; + res = SELECT s + FROM seeds:s -(e_types:e)- :t + WHERE (filter_by is NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk) + ACCUM + {EDGEATTRSHTTP} + LIMIT 1; + + FOREACH i IN @@e_batch DO + PRINT i as data_batch; + END; + END; END; } \ No newline at end of file diff --git a/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql b/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql index c3fbc930..25b8df47 100644 --- a/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql +++ b/pyTigerGraph/gds/gsql/dataloaders/vertex_loader.gsql @@ -3,6 +3,8 @@ CREATE QUERY vertex_loader_{QUERYSUFFIX}( STRING filter_by, SET v_types, STRING delimiter, + BOOL shuffle=FALSE, + INT num_chunks=2, STRING kafka_address="", STRING kafka_topic="", INT kafka_topic_partitions=1, @@ -41,15 +43,33 @@ CREATE QUERY vertex_loader_{QUERYSUFFIX}( sasl_password : SASL password for Kafka. ssl_ca_location: Path to CA certificate for verifying the Kafka broker key. 
*/ + SumAccum @tmp_id; + # If getting all vertices of given types IF input_vertices.size()==0 THEN start = {v_types}; + # Filter seeds if needed + seeds = SELECT s + FROM start:s + WHERE filter_by is NULL OR s.getAttr(filter_by, "BOOL"); + # Shuffle vertex ID if needed + IF shuffle THEN + INT num_vertices = seeds.size(); + res = SELECT s + FROM seeds:s + POST-ACCUM s.@tmp_id = floor(rand()*num_vertices) + LIMIT 1; + ELSE + res = SELECT s + FROM seeds:s + POST-ACCUM s.@tmp_id = getvid(s) + LIMIT 1; + END; - # Generate batches + # Export data # If using kafka to export IF kafka_address != "" THEN SumAccum @@kafka_error; - SumAccum @v_data; # Initialize Kafka producer UINT producer = init_kafka_producer( @@ -59,13 +79,15 @@ CREATE QUERY vertex_loader_{QUERYSUFFIX}( ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, sasl_kerberos_keytab, sasl_kerberos_principal); - res = SELECT s - FROM start:s - WHERE filter_by is NULL OR s.getAttr(filter_by, "BOOL") - POST-ACCUM - {VERTEXATTRSKAFKA} - LIMIT 1; - + FOREACH chunk IN RANGE[0, num_chunks-1] DO + res = SELECT s + FROM seeds:s + WHERE s.@tmp_id % num_chunks == chunk + POST-ACCUM + {VERTEXATTRSKAFKA} + LIMIT 1; + END; + FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO INT kafka_errcode = write_to_kafka(producer, kafka_topic, i, "STOP", ""); IF kafka_errcode!=0 THEN @@ -80,15 +102,18 @@ CREATE QUERY vertex_loader_{QUERYSUFFIX}( PRINT @@kafka_error as kafkaError; # Else return as http response ELSE - ListAccum @@v_batch; - res = SELECT s - FROM start:s - WHERE filter_by is NULL OR s.getAttr(filter_by, "BOOL") - POST-ACCUM - {VERTEXATTRSHTTP}; - - FOREACH i IN @@v_batch DO - PRINT i as data_batch; + FOREACH chunk IN RANGE[0, num_chunks-1] DO + ListAccum @@v_batch; + res = SELECT s + FROM seeds:s + WHERE s.@tmp_id % num_chunks == chunk + POST-ACCUM + {VERTEXATTRSHTTP} + LIMIT 1; + + FOREACH i IN @@v_batch DO + PRINT i as data_batch; + END; END; END; # Else get given vertices. 
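With this change the GSQL loaders shuffle and chunk the data server-side, while batch assembly moves to the Python reader threads shown earlier (_read_vertex_data / _read_edge_data): chunks of whatever size the query or Kafka partition delivers are buffered until batch_size rows are available, and a trailing partial batch is flushed when the None sentinel arrives. The following is a minimal, self-contained sketch of that buffering pattern under those assumptions, not the library code itself; the function name rebatch and its signature are illustrative only.

    # Sketch only: illustrates the chunk-to-batch buffering used by the reader
    # threads; the real loaders also check an exit event and parse each row.
    from queue import Empty, Queue
    from typing import Iterator, List

    def rebatch(in_q: Queue, batch_size: int, timeout: float = 1.0) -> Iterator[List[str]]:
        buffer: List[str] = []
        while True:
            try:
                raw = in_q.get(timeout=timeout)
            except Empty:
                continue  # keep polling; the library also tests an exit flag here
            if raw is None:
                break  # sentinel: the REST requester or Kafka downloader is done
            buffer.extend(raw.splitlines())  # one chunk may carry many rows
            while len(buffer) >= batch_size:
                yield buffer[:batch_size]
                buffer = buffer[batch_size:]
        if buffer:
            yield buffer  # final, possibly smaller, batch

    q: Queue = Queue()
    for chunk in ("a|1\nb|2\nc|3\n", "d|4\n"):
        q.put(chunk)
    q.put(None)
    print(list(rebatch(q, batch_size=2)))  # [['a|1', 'b|2'], ['c|3', 'd|4']]

A producer thread would put newline-delimited chunk strings on the queue and finish with None; iterating rebatch(q, 1024) then yields fixed-size batches regardless of how the server chunked the stream, which is why the tests below assert that every batch except the last has exactly batch_size rows.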
diff --git a/tests/test_gds_EdgeLoader.py b/tests/test_gds_EdgeLoader.py index 5e22671c..eac8075d 100644 --- a/tests/test_gds_EdgeLoader.py +++ b/tests/test_gds_EdgeLoader.py @@ -7,7 +7,7 @@ from pyTigerGraph.gds.utilities import is_query_installed -class TestGDSEdgeLoader(unittest.TestCase): +class TestGDSEdgeLoaderKafka(unittest.TestCase): @classmethod def setUpClass(cls): cls.conn = make_connection(graphname="Cora") @@ -17,44 +17,47 @@ def test_init(self): graph=self.conn, batch_size=1024, shuffle=False, - filter_by=None, - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 11) + self.assertIsNone(loader.num_batches) def test_iterate(self): loader = EdgeLoader( graph=self.conn, batch_size=1024, shuffle=True, - filter_by=None, - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data.head()) self.assertIsInstance(data, DataFrame) + self.assertIn("source", data) + self.assertIn("target", data) num_batches += 1 + self.assertEqual(data.shape[1], 2) + batch_sizes.append(data.shape[0]) self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) def test_whole_edgelist(self): loader = EdgeLoader( graph=self.conn, num_batches=1, - shuffle=True, - filter_by=None, - loader_id=None, - buffer_size=4, + shuffle=False, kafka_address="kafka:9092", ) data = loader.data # print(data) self.assertIsInstance(data, DataFrame) + self.assertIn("source", data) + self.assertIn("target", data) + self.assertEqual(data.shape[0], 10556) + self.assertEqual(data.shape[1], 2) def test_iterate_attr(self): loader = EdgeLoader( @@ -62,19 +65,46 @@ def test_iterate_attr(self): attributes=["time", "is_train"], batch_size=1024, shuffle=True, - filter_by=None, - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data.head()) + self.assertIsInstance(data, DataFrame) + self.assertIn("time", data) + self.assertIn("is_train", data) + num_batches += 1 + self.assertEqual(data.shape[1], 4) + batch_sizes.append(data.shape[0]) + self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) + + def test_iterate_attr_multichar_delimiter(self): + loader = EdgeLoader( + graph=self.conn, + attributes=["time", "is_train"], + batch_size=1024, + shuffle=True, + kafka_address="kafka:9092", + delimiter="|$" + ) + num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data.head()) self.assertIsInstance(data, DataFrame) self.assertIn("time", data) self.assertIn("is_train", data) num_batches += 1 + self.assertEqual(data.shape[1], 4) + batch_sizes.append(data.shape[0]) self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) def test_sasl_plaintext(self): loader = EdgeLoader( @@ -130,6 +160,111 @@ def test_sasl_ssl(self): # TODO: test filter_by +class TestGDSHeteroEdgeLoaderKafka(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.conn = make_connection(graphname="hetero") + + def test_init(self): + loader = EdgeLoader( + graph=self.conn, + batch_size=1024, + shuffle=False, + kafka_address="kafka:9092" + ) + self.assertTrue(is_query_installed(self.conn, loader.query_name)) + 
self.assertIsNone(loader.num_batches) + + def test_iterate_as_homo(self): + loader = EdgeLoader( + graph=self.conn, + batch_size=1024, + shuffle=False, + kafka_address="kafka:9092" + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data.head()) + self.assertIsInstance(data, DataFrame) + self.assertIn("source", data) + self.assertIn("target", data) + num_batches += 1 + self.assertEqual(data.shape[1], 2) + batch_sizes.append(data.shape[0]) + self.assertEqual(num_batches, 6) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) + + def test_iterate_hetero(self): + loader = EdgeLoader( + graph=self.conn, + attributes={"v0v0": ["is_train", "is_val"], "v2v0": ["is_train", "is_val"]}, + batch_size=200, + shuffle=True, + kafka_address="kafka:9092" + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + batchsize = 0 + if "v0v0" in data: + self.assertIsInstance(data["v0v0"], DataFrame) + self.assertIn("is_val", data["v0v0"]) + self.assertIn("is_train", data["v0v0"]) + batchsize += data["v0v0"].shape[0] + self.assertEqual(data["v0v0"].shape[1], 4) + if "v2v0" in data: + self.assertIsInstance(data["v2v0"], DataFrame) + self.assertIn("is_val", data["v2v0"]) + self.assertIn("is_train", data["v2v0"]) + batchsize += data["v2v0"].shape[0] + self.assertEqual(data["v2v0"].shape[1], 4) + self.assertGreater(len(data), 0) + num_batches += 1 + batch_sizes.append(batchsize) + self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 200) + self.assertLessEqual(batch_sizes[-1], 200) + + def test_iterate_hetero_multichar_delimiter(self): + loader = EdgeLoader( + graph=self.conn, + attributes={"v0v0": ["is_train", "is_val"], "v2v0": ["is_train", "is_val"]}, + batch_size=200, + shuffle=True, + delimiter="|$", + kafka_address="kafka:9092" + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + batchsize = 0 + if "v0v0" in data: + self.assertIsInstance(data["v0v0"], DataFrame) + self.assertIn("is_val", data["v0v0"]) + self.assertIn("is_train", data["v0v0"]) + batchsize += data["v0v0"].shape[0] + self.assertEqual(data["v0v0"].shape[1], 4) + if "v2v0" in data: + self.assertIsInstance(data["v2v0"], DataFrame) + self.assertIn("is_val", data["v2v0"]) + self.assertIn("is_train", data["v2v0"]) + batchsize += data["v2v0"].shape[0] + self.assertEqual(data["v2v0"].shape[1], 4) + self.assertGreater(len(data), 0) + num_batches += 1 + batch_sizes.append(batchsize) + self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 200) + self.assertLessEqual(batch_sizes[-1], 200) + + class TestGDSEdgeLoaderREST(unittest.TestCase): @classmethod def setUpClass(cls): @@ -140,41 +275,45 @@ def test_init(self): graph=self.conn, batch_size=1024, shuffle=False, - filter_by=None, - loader_id=None, - buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 11) + self.assertIsNone(loader.num_batches) def test_iterate(self): loader = EdgeLoader( graph=self.conn, batch_size=1024, shuffle=True, - filter_by=None, - loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data.head()) self.assertIsInstance(data, DataFrame) + self.assertIn("source", data) + self.assertIn("target", data) num_batches += 1 + self.assertEqual(data.shape[1], 2) + batch_sizes.append(data.shape[0]) self.assertEqual(num_batches, 11) + for i in 
batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) def test_whole_edgelist(self): loader = EdgeLoader( graph=self.conn, num_batches=1, - shuffle=True, - filter_by=None, - loader_id=None, - buffer_size=4, + shuffle=False, ) data = loader.data # print(data) self.assertIsInstance(data, DataFrame) + self.assertIn("source", data) + self.assertIn("target", data) + self.assertEqual(data.shape[0], 10556) + self.assertEqual(data.shape[1], 2) + def test_iterate_attr(self): loader = EdgeLoader( @@ -182,18 +321,21 @@ def test_iterate_attr(self): attributes=["time", "is_train"], batch_size=1024, shuffle=True, - filter_by=None, - loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data.head()) self.assertIsInstance(data, DataFrame) self.assertIn("time", data) self.assertIn("is_train", data) num_batches += 1 + self.assertEqual(data.shape[1], 4) + batch_sizes.append(data.shape[0]) self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) def test_iterate_attr_multichar_delimiter(self): loader = EdgeLoader( @@ -201,19 +343,22 @@ def test_iterate_attr_multichar_delimiter(self): attributes=["time", "is_train"], batch_size=1024, shuffle=True, - filter_by=None, - loader_id=None, - buffer_size=4, delimiter="|$" ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data.head()) self.assertIsInstance(data, DataFrame) self.assertIn("time", data) self.assertIn("is_train", data) num_batches += 1 + self.assertEqual(data.shape[1], 4) + batch_sizes.append(data.shape[0]) self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) # TODO: test filter_by @@ -228,87 +373,110 @@ def test_init(self): graph=self.conn, batch_size=1024, shuffle=False, - filter_by=None, - loader_id=None, - buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 6) + self.assertIsNone(loader.num_batches) def test_iterate_as_homo(self): loader = EdgeLoader( graph=self.conn, batch_size=1024, shuffle=False, - filter_by=None, - loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data.head()) self.assertIsInstance(data, DataFrame) + self.assertIn("source", data) + self.assertIn("target", data) num_batches += 1 + self.assertEqual(data.shape[1], 2) + batch_sizes.append(data.shape[0]) self.assertEqual(num_batches, 6) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) def test_iterate_hetero(self): loader = EdgeLoader( graph=self.conn, attributes={"v0v0": ["is_train", "is_val"], "v2v0": ["is_train", "is_val"]}, batch_size=200, - shuffle=True, # Needed to get around VID distribution issues - filter_by=None, - loader_id=None, - buffer_size=4, + shuffle=True, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) - self.assertEqual(len(data), 2) - self.assertIsInstance(data["v0v0"], DataFrame) - self.assertIn("is_val", data["v0v0"]) - self.assertIn("is_train", data["v0v0"]) - self.assertIsInstance(data["v2v0"], DataFrame) - self.assertIn("is_val", data["v2v0"]) - self.assertIn("is_train", data["v2v0"]) + batchsize = 0 + if "v0v0" in data: + self.assertIsInstance(data["v0v0"], DataFrame) + self.assertIn("is_val", data["v0v0"]) + self.assertIn("is_train", data["v0v0"]) + batchsize += 
data["v0v0"].shape[0] + self.assertEqual(data["v0v0"].shape[1], 4) + if "v2v0" in data: + self.assertIsInstance(data["v2v0"], DataFrame) + self.assertIn("is_val", data["v2v0"]) + self.assertIn("is_train", data["v2v0"]) + batchsize += data["v2v0"].shape[0] + self.assertEqual(data["v2v0"].shape[1], 4) + self.assertGreater(len(data), 0) num_batches += 1 + batch_sizes.append(batchsize) self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 200) + self.assertLessEqual(batch_sizes[-1], 200) def test_iterate_hetero_multichar_delimiter(self): loader = EdgeLoader( graph=self.conn, attributes={"v0v0": ["is_train", "is_val"], "v2v0": ["is_train", "is_val"]}, batch_size=200, - shuffle=True, # Needed to get around VID distribution issues - filter_by=None, - loader_id=None, - buffer_size=4, + shuffle=True, delimiter="|$" ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) - if num_batches == 0: - self.assertEqual(data["v0v0"].shape[0]+data["v2v0"].shape[0], 200) - self.assertEqual(len(data), 2) - self.assertIsInstance(data["v0v0"], DataFrame) - self.assertIn("is_val", data["v0v0"]) - self.assertIn("is_train", data["v0v0"]) - self.assertIsInstance(data["v2v0"], DataFrame) - self.assertIn("is_val", data["v2v0"]) - self.assertIn("is_train", data["v2v0"]) + batchsize = 0 + if "v0v0" in data: + self.assertIsInstance(data["v0v0"], DataFrame) + self.assertIn("is_val", data["v0v0"]) + self.assertIn("is_train", data["v0v0"]) + batchsize += data["v0v0"].shape[0] + self.assertEqual(data["v0v0"].shape[1], 4) + if "v2v0" in data: + self.assertIsInstance(data["v2v0"], DataFrame) + self.assertIn("is_val", data["v2v0"]) + self.assertIn("is_train", data["v2v0"]) + batchsize += data["v2v0"].shape[0] + self.assertEqual(data["v2v0"].shape[1], 4) + self.assertGreater(len(data), 0) num_batches += 1 + batch_sizes.append(batchsize) self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 200) + self.assertLessEqual(batch_sizes[-1], 200) if __name__ == "__main__": suite = unittest.TestSuite() - suite.addTest(TestGDSEdgeLoader("test_init")) - suite.addTest(TestGDSEdgeLoader("test_iterate")) - suite.addTest(TestGDSEdgeLoader("test_whole_edgelist")) - suite.addTest(TestGDSEdgeLoader("test_iterate_attr")) + suite.addTest(TestGDSEdgeLoaderKafka("test_init")) + suite.addTest(TestGDSEdgeLoaderKafka("test_iterate")) + suite.addTest(TestGDSEdgeLoaderKafka("test_whole_edgelist")) + suite.addTest(TestGDSEdgeLoaderKafka("test_iterate_attr")) + suite.addTest(TestGDSEdgeLoaderKafka("test_iterate_attr_multichar_delimiter")) # suite.addTest(TestGDSEdgeLoader("test_sasl_plaintext")) # suite.addTest(TestGDSEdgeLoader("test_sasl_ssl")) + suite.addTest(TestGDSHeteroEdgeLoaderKafka("test_init")) + suite.addTest(TestGDSHeteroEdgeLoaderKafka("test_iterate_as_homo")) + suite.addTest(TestGDSHeteroEdgeLoaderKafka("test_iterate_hetero")) + suite.addTest(TestGDSHeteroEdgeLoaderKafka("test_iterate_hetero_multichar_delimiter")) suite.addTest(TestGDSEdgeLoaderREST("test_init")) suite.addTest(TestGDSEdgeLoaderREST("test_iterate")) suite.addTest(TestGDSEdgeLoaderREST("test_whole_edgelist")) From d8b816bf4f96f2f04b9d74a66811fa93ae52dfba Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Mon, 23 Oct 2023 22:26:15 -0700 Subject: [PATCH 16/36] feat(GraphLoader): update gsql and add subquery --- .../gds/gsql/dataloaders/graph_loader.gsql | 136 ++++++++---------- .../gsql/dataloaders/graph_loader_sub.gsql | 11 ++ 2 files changed, 74 insertions(+), 73 deletions(-) create mode 100644 
pyTigerGraph/gds/gsql/dataloaders/graph_loader_sub.gsql diff --git a/pyTigerGraph/gds/gsql/dataloaders/graph_loader.gsql b/pyTigerGraph/gds/gsql/dataloaders/graph_loader.gsql index b4035b90..5e0d5f61 100644 --- a/pyTigerGraph/gds/gsql/dataloaders/graph_loader.gsql +++ b/pyTigerGraph/gds/gsql/dataloaders/graph_loader.gsql @@ -1,10 +1,10 @@ CREATE QUERY graph_loader_{QUERYSUFFIX}( - INT num_batches=1, - BOOL shuffle=FALSE, STRING filter_by, SET v_types, SET e_types, STRING delimiter, + BOOL shuffle=FALSE, + INT num_chunks=2, STRING kafka_address="", STRING kafka_topic, INT kafka_topic_partitions=1, @@ -40,93 +40,83 @@ CREATE QUERY graph_loader_{QUERYSUFFIX}( sasl_password : SASL password for Kafka. ssl_ca_location: Path to CA certificate for verifying the Kafka broker key. */ - INT num_vertices; - INT kafka_errcode; SumAccum @tmp_id; - SumAccum @@kafka_error; - UINT producer; - # Initialize Kafka producer - IF kafka_address != "" THEN - producer = init_kafka_producer( - kafka_address, kafka_max_size, security_protocol, - sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, - ssl_certificate_location, ssl_key_location, ssl_key_password, - ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, - sasl_kerberos_keytab, sasl_kerberos_principal); - END; - - # Shuffle vertex ID if needed start = {v_types}; + # Filter seeds if needed + seeds = SELECT s + FROM start:s -(e_types:e)- v_types:t + WHERE filter_by is NULL OR e.getAttr(filter_by, "BOOL") + POST-ACCUM s.@tmp_id = getvid(s) + POST-ACCUM t.@tmp_id = getvid(t); + # Shuffle vertex ID if needed IF shuffle THEN - IF filter_by IS NOT NULL THEN - start = SELECT s FROM start:s WHERE s.getAttr(filter_by, "BOOL"); - END; - num_vertices = start.size(); + INT num_vertices = seeds.size(); res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = floor(rand()*num_vertices); - ELSE - res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = getvid(s); + FROM seeds:s + POST-ACCUM s.@tmp_id = floor(rand()*num_vertices) + LIMIT 1; END; # Generate batches - FOREACH batch_id IN RANGE[0, num_batches-1] DO - SetAccum @@vertices; - SumAccum @@e_batch; - SumAccum @@v_batch; - - start = {v_types}; - IF filter_by IS NOT NULL THEN - res = - SELECT s - FROM start:s -(e_types:e)- v_types:t - WHERE e.getAttr(filter_by, "BOOL") and ((s.@tmp_id+t.@tmp_id)*(s.@tmp_id+t.@tmp_id+1)/2+t.@tmp_id)%num_batches==batch_id - ACCUM - {EDGEATTRS}, - @@vertices += s, - @@vertices += t; - ELSE - res = - SELECT s - FROM start:s -(e_types:e)- v_types:t - WHERE ((s.@tmp_id+t.@tmp_id)*(s.@tmp_id+t.@tmp_id+1)/2+t.@tmp_id)%num_batches==batch_id - ACCUM - {EDGEATTRS}, - @@vertices += s, - @@vertices += t; + # If using kafka to export + IF kafka_address != "" THEN + SumAccum @@kafka_error; + + # Initialize Kafka producer + UINT producer = init_kafka_producer( + kafka_address, kafka_max_size, security_protocol, + sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, + ssl_certificate_location, ssl_key_location, ssl_key_password, + ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, + sasl_kerberos_keytab, sasl_kerberos_principal); + + FOREACH chunk IN RANGE[0, num_chunks-1] DO + res = SELECT s + FROM seeds:s -(e_types:e)- v_types:t + WHERE (filter_by is NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk) + ACCUM + STRING s_msg = graph_loader_sub_{QUERYSUFFIX}(s, delimiter), + STRING t_msg = graph_loader_sub_{QUERYSUFFIX}(t, delimiter), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, 
(getvid(s)+getvid(t))%kafka_topic_partitions, "vertex_batch_" + stringify(getvid(s))+e.type+stringify(getvid(t)), s_msg+t_msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending vertex data for edge " + stringify(getvid(s))+e.type+stringify(getvid(t)) + ": "+ stringify(kafka_errcode) + "\\n") + END, + {EDGEATTRSKAFKA} + LIMIT 1; END; - - # Get vertex attributes - v_in_batch = @@vertices; - attr = - SELECT s - FROM v_in_batch:s - POST-ACCUM - {VERTEXATTRS}; - IF kafka_address != "" THEN - # Write to kafka - kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "vertex_batch_" + stringify(batch_id), @@v_batch); - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending vertex batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n"); - END; - kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "edge_batch_" + stringify(batch_id), @@e_batch); + FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO + INT kafka_errcode = write_to_kafka(producer, kafka_topic, i, "STOP", ""); IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending edge batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n"); + @@kafka_error += ("Error sending STOP signal to topic partition " + stringify(i) + ": " + stringify(kafka_errcode) + "\n"); END; - ELSE - # Add to response - PRINT @@v_batch AS vertex_batch, @@e_batch AS edge_batch; END; - END; - IF kafka_address != "" THEN - kafka_errcode = close_kafka_producer(producer, kafka_timeout); + + INT kafka_errcode = close_kafka_producer(producer, kafka_timeout); IF kafka_errcode!=0 THEN @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); END; PRINT @@kafka_error as kafkaError; + # Else return as http response + ELSE + FOREACH chunk IN RANGE[0, num_chunks-1] DO + MapAccum @@v_batch; + MapAccum @@e_batch; + + res = SELECT s + FROM seeds:s -(e_types:e)- v_types:t + WHERE (filter_by is NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk) + ACCUM + STRING s_msg = graph_loader_sub_{QUERYSUFFIX}(s, delimiter), + STRING t_msg = graph_loader_sub_{QUERYSUFFIX}(t, delimiter), + @@v_batch += (stringify(getvid(s))+e.type+stringify(getvid(t)) -> s_msg+t_msg), + {EDGEATTRSHTTP} + LIMIT 1; + + FOREACH (k,v) IN @@v_batch DO + PRINT v as vertex_batch, @@e_batch.get(k) as edge_batch; + END; + END; END; } \ No newline at end of file diff --git a/pyTigerGraph/gds/gsql/dataloaders/graph_loader_sub.gsql b/pyTigerGraph/gds/gsql/dataloaders/graph_loader_sub.gsql new file mode 100644 index 00000000..11eb6993 --- /dev/null +++ b/pyTigerGraph/gds/gsql/dataloaders/graph_loader_sub.gsql @@ -0,0 +1,11 @@ +CREATE QUERY graph_loader_sub_{QUERYSUFFIX} (VERTEX v, STRING delimiter) +RETURNS (STRING) +{ + STRING ret; + start = {v}; + res = SELECT s + FROM start:s + POST-ACCUM + {VERTEXATTRS}; + RETURN ret; +} From 7b2970ceae7177cce8a426c2146cab26271d90d2 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Mon, 23 Oct 2023 22:33:32 -0700 Subject: [PATCH 17/36] feat(dataloaders): update for GraphLoader --- pyTigerGraph/gds/dataloaders.py | 276 ++++++++++++++++++-------------- 1 file changed, 152 insertions(+), 124 deletions(-) diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index d43027de..44ada0aa 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -39,7 +39,7 @@ import pandas as pd from ..pyTigerGraphException import TigerGraphException -from .utilities import 
install_query_file, random_string, add_attribute +from .utilities import install_query_file, random_string, add_attribute, install_query_files __all__ = ["VertexLoader", "EdgeLoader", "NeighborLoader", "GraphLoader", "EdgeNeighborLoader", "NodePieceLoader", "HGTLoader"] @@ -183,8 +183,6 @@ def __init__( self._downloader = None self._reader = None # Queues to store tasks and data - self._request_task_q = None - self._download_task_q = None self._read_task_q = None self._data_q = None self._kafka_topic = None @@ -586,28 +584,27 @@ def _request_unimode_rest( def _download_graph_kafka( exit_event: Event, read_task_q: Queue, - kafka_consumer: "KafkaConsumer", - max_wait_time: int = 300 + kafka_consumer: "KafkaConsumer" ) -> NoReturn: - delivered_batch = 0 + empty = False buffer = {} - wait_time = 0 - while (not exit_event.is_set()) and (wait_time < max_wait_time): + while (not exit_event.is_set()) and (not empty): resp = kafka_consumer.poll(1000) if not resp: - wait_time += 1 continue - wait_time = 0 for msgs in resp.values(): for message in msgs: key = message.key.decode("utf-8") + if key == "STOP": + read_task_q.put(None) + empty = True + continue if key.startswith("vertex"): companion_key = key.replace("vertex", "edge") if companion_key in buffer: read_task_q.put((message.value.decode("utf-8"), buffer[companion_key])) del buffer[companion_key] - delivered_batch += 1 else: buffer[key] = message.value.decode("utf-8") elif key.startswith("edge"): @@ -616,11 +613,10 @@ def _download_graph_kafka( read_task_q.put((buffer[companion_key], message.value.decode("utf-8"))) del buffer[companion_key] - delivered_batch += 1 else: buffer[key] = message.value.decode("utf-8") else: - raise ValueError( + warnings.warn( "Unrecognized key {} for messages in kafka".format(key) ) @@ -650,7 +646,6 @@ def _read_graph_data( in_q: Queue, out_q: Queue, batch_size: int, - shuffle: bool = False, out_format: str = "dataframe", v_in_feats: Union[list, dict] = [], v_out_labels: Union[list, dict] = [], @@ -708,25 +703,25 @@ def _read_graph_data( "Spektral is not installed. Please install it to use spektral output." 
) # Get raw data from queue and parse - vertex_buffer = [] - edge_buffer = [] + vertex_buffer = set() + edge_buffer = set() buffer_size = 0 - while not exit_event.is_set(): + is_empty = False + last_batch = False + while (not exit_event.is_set()) and (not is_empty): try: raw = in_q.get(timeout=1) except Empty: continue - # if shuffle the data, 50% chance to save this data point for later - if shuffle and (random.random() < 0.5): - in_q.task_done() - in_q.put(raw) - continue - # Store raw into buffer until there are enough data points for a batch - vertex_buffer.extend(raw[0].splitlines()) - edge_buffer.extend(raw[1].splitlines()) - buffer_size += 1 - in_q.task_done() - if buffer_size < batch_size: + if raw is None: + is_empty = True + if buffer_size > 0: + last_batch = True + else: + vertex_buffer.update(raw[0].splitlines()) + edge_buffer.update(raw[1].splitlines()) + buffer_size += 1 + if (buffer_size < batch_size) and (not last_batch): continue try: data = BaseLoader._parse_graph_data_to_df( @@ -820,6 +815,7 @@ def _read_graph_data( vertex_buffer.clear() edge_buffer.clear() buffer_size = 0 + out_q.put(None) @staticmethod def _read_vertex_data( @@ -897,7 +893,7 @@ def _read_edge_data( buffer = [] is_empty = False last_batch = False - while not exit_event.is_set() and (not is_empty): + while (not exit_event.is_set()) and (not is_empty): try: raw = in_q.get(timeout=1) except Empty: @@ -1632,7 +1628,6 @@ def _start_request(self, is_graph: bool): exit_event = self._exit_event, read_task_q = self._read_task_q, kafka_consumer = self._kafka_consumer, - max_wait_time = self.timeout ), ) else: @@ -1687,7 +1682,6 @@ def _start_request(self, is_graph: bool): def _start(self) -> None: # This is a template. Implement your own logics here. # Create task and result queues - self._request_task_q = Queue() self._read_task_q = Queue() self._data_q = Queue(self._buffer_size) self._exit_event = Event() @@ -1753,48 +1747,38 @@ def data(self) -> Any: return self def _reset(self, theend=False) -> None: - logger.debug("Resetting the data loader") + logger.debug("Resetting data loader") if self._exit_event: self._exit_event.set() - if self._request_task_q: - while True: - try: - self._request_task_q.get(block=False) - except Empty: - break - if self._download_task_q: - while True: - try: - self._download_task_q.get(block=False) - except Empty: - break + logger.debug("Set exit event") if self._read_task_q: while True: try: - self._read_task_q.get(block=False) + self._read_task_q.get(timeout=1) except Empty: break + logger.debug("Emptied read task queue") if self._data_q: while True: try: - self._data_q.get(block=False) + self._data_q.get(timeout=1) except Empty: break + logger.debug("Emptied data queue") if self._requester: self._requester.join() + logger.debug("Stopped requester thread") if self._downloader: self._downloader.join() + logger.debug("Stopped downloader thread") if self._reader: self._reader.join() - del self._request_task_q, self._download_task_q, self._read_task_q, self._data_q + logger.debug("Stopped reader thread") + del self._read_task_q, self._data_q self._exit_event = None self._requester, self._downloader, self._reader = None, None, None - self._request_task_q, self._download_task_q, self._read_task_q, self._data_q = ( - None, - None, - None, - None - ) + self._read_task_q, self._data_q = None, None + logger.debug("Deleted all queues and threads") if theend: if self._kafka_topic and self._kafka_consumer: self._kafka_consumer.unsubscribe() @@ -1806,6 +1790,7 @@ def _reset(self, theend=False) 
-> None: raise TigerGraphException( "Failed to delete topic {}".format(del_res["topic"]) ) + logger.debug("Finished with Kafka. Reached the end.") else: if self.delete_epoch_topic and self._kafka_admin: if self._kafka_topic and self._kafka_consumer: @@ -1817,7 +1802,8 @@ def _reset(self, theend=False) -> None: "Failed to delete topic {}".format(del_res["topic"]) ) self._kafka_topic = None - logger.debug("Successfully reset the data loader") + logger.debug("Finished with Kafka") + logger.debug("Reset data loader successfully") def _generate_attribute_string(self, schema_type, attr_names, attr_types) -> str: if schema_type.lower() == "vertex": @@ -2519,9 +2505,21 @@ def __init__( else: # If number of batches is given, calculate batch size if filter_by: - num_edges = sum(self._graph.getEdgeStats(e_type)[e_type][filter_by if isinstance(filter_by, str) else filter_by[e_type]]["TRUE"] for e_type in self._etypes) + num_edges = 0 + for e_type in self._etypes: + tmp = self._graph.getEdgeStats(e_type)[e_type][filter_by if isinstance(filter_by, str) else filter_by[e_type]]["TRUE"] + if self._e_schema[e_type]["IsDirected"]: + num_edges += tmp + else: + num_edges += 2*tmp else: - num_edges = sum(self._graph.getEdgeCount(i) for i in self._etypes) + num_edges = 0 + for e_type in self._etypes: + tmp = self._graph.getEdgeCount(e_type) + if self._e_schema[e_type]["IsDirected"]: + num_edges += tmp + else: + num_edges += 2*tmp self.batch_size = math.ceil(num_edges / num_batches) self.num_batches = num_batches # Initialize the exporter @@ -3194,24 +3192,35 @@ def __init__( self._etypes = sorted(self._etypes) # Initialize parameters for the query if batch_size: - # If batch_size is given, calculate the number of batches + # batch size takes precedence over number of batches + self.batch_size = batch_size + self.num_batches = None + else: + # If number of batches is given, calculate batch size if filter_by: - # TODO: get edge count with filter - raise NotImplementedError + num_edges = 0 + for e_type in self._etypes: + tmp = self._graph.getEdgeStats(e_type)[e_type][filter_by if isinstance(filter_by, str) else filter_by[e_type]]["TRUE"] + if self._e_schema[e_type]["IsDirected"]: + num_edges += tmp + else: + num_edges += 2*tmp else: - num_edges = sum(self._graph.getEdgeCount(i) for i in self._etypes) - self.num_batches = math.ceil(num_edges / batch_size) - else: - # Otherwise, take the number of batches as is. 
+ num_edges = 0 + for e_type in self._etypes: + tmp = self._graph.getEdgeCount(e_type) + if self._e_schema[e_type]["IsDirected"]: + num_edges += tmp + else: + num_edges += 2*tmp + self.batch_size = math.ceil(num_edges / num_batches) self.num_batches = num_batches - self._payload["num_batches"] = self.num_batches if filter_by: self._payload["filter_by"] = filter_by self._payload["shuffle"] = shuffle self._payload["v_types"] = self._vtypes self._payload["e_types"] = self._etypes self._payload["delimiter"] = self.delimiter - self._payload["num_heap_inserts"] = self.num_heap_inserts # Output self.add_self_loop = add_self_loop # Install query @@ -3232,9 +3241,11 @@ def _install_query(self, force: bool = False) -> str: md5.update(json.dumps(query_suffix).encode()) query_replace = {"{QUERYSUFFIX}": md5.hexdigest()} + print_vertex_attr = "" + print_edge_http = "" + print_edge_kafka = "" if isinstance(self.v_in_feats, dict): # Multiple vertex types - print_query = "" for idx, vtype in enumerate(self._vtypes): v_attr_names = ( self.v_in_feats.get(vtype, []) @@ -3242,17 +3253,16 @@ def _install_query(self, force: bool = False) -> str: + self.v_extra_feats.get(vtype, []) ) v_attr_types = self._v_schema[vtype] - if v_attr_names: - print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype, print_attr) - else: - print_query += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype) - print_query += "END" - query_replace["{VERTEXATTRS}"] = print_query + print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) + print_vertex_attr += """ + {} s.type == "{}" THEN + ret = (s.type + delimiter + stringify(getvid(s)) {}+ "\\n")"""\ + .format("IF" if idx==0 else "ELSE IF", + vtype, + "+ delimiter + {}".format(print_attr) if v_attr_names else "") + print_vertex_attr += """ + END""" # Multiple edge types - print_query = "" for idx, etype in enumerate(self._etypes): e_attr_names = ( self.e_in_feats.get(etype, []) @@ -3260,38 +3270,51 @@ def _install_query(self, force: bool = False) -> str: + self.e_extra_feats.get(etype, []) ) e_attr_types = self._e_schema[etype] - if e_attr_names: - print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) - print_query += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", etype, print_attr) - else: - print_query += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", etype) - print_query += "END" - query_replace["{EDGEATTRS}"] = print_query + print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) + print_edge_http += """ + {} e.type == "{}" THEN + STRING e_msg = (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ "\\n"), + @@e_batch += (stringify(getvid(s))+e.type+stringify(getvid(t)) -> e_msg)"""\ + .format("IF" if idx==0 else "ELSE IF", etype, + "+ delimiter + " + print_attr if e_attr_names else "") + print_edge_kafka += """ + {} e.type == "{}" THEN + STRING e_msg = (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) 
{}+ "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, (getvid(s)+getvid(t))%kafka_topic_partitions, "edge_batch_" + stringify(getvid(s))+e.type+stringify(getvid(t)), e_msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending edge data for edge " + stringify(getvid(s))+e.type+stringify(getvid(t)) + ": "+ stringify(kafka_errcode) + "\\n") + END""".format("IF" if idx==0 else "ELSE IF", etype, + "+ delimiter + " + print_attr if e_attr_names else "") + print_edge_http += """ + END""" + print_edge_kafka += """ + END""" else: # Ignore vertex types v_attr_names = self.v_in_feats + self.v_out_labels + self.v_extra_feats v_attr_types = next(iter(self._v_schema.values())) - if v_attr_names: - print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query = '@@v_batch += (stringify(getvid(s)) + delimiter + {} + "\\n")'.format( - print_attr - ) - else: - print_query = '@@v_batch += (stringify(getvid(s)) + "\\n")' - query_replace["{VERTEXATTRS}"] = print_query + print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) + print_vertex_attr += """ + ret = (stringify(getvid(s)) {}+ "\\n")"""\ + .format("+ delimiter + " + print_attr if v_attr_names else "") # Ignore edge types e_attr_names = self.e_in_feats + self.e_out_labels + self.e_extra_feats e_attr_types = next(iter(self._e_schema.values())) - if e_attr_names: - print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n")'.format( - print_attr - ) - else: - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n")' - query_replace["{EDGEATTRS}"] = print_query + print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) + print_edge_http += """ + STRING e_msg = (stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ "\\n"), + @@e_batch += (stringify(getvid(s))+e.type+stringify(getvid(t)) -> e_msg)"""\ + .format("+ delimiter + " + print_attr if e_attr_names else "") + print_edge_kafka += """ + STRING e_msg = (stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ "\\n"), + INT kafka_errcode2 = write_to_kafka(producer, kafka_topic, (getvid(s)+getvid(t))%kafka_topic_partitions, "edge_batch_" + stringify(getvid(s))+e.type+stringify(getvid(t)), e_msg), + IF kafka_errcode2!=0 THEN + @@kafka_error += ("Error sending edge data for edge " + stringify(getvid(s))+e.type+stringify(getvid(t)) + ": "+ stringify(kafka_errcode2) + "\\n") + END""".format("+ delimiter + " + print_attr if e_attr_names else "") + query_replace["{VERTEXATTRS}"] = print_vertex_attr + query_replace["{EDGEATTRSKAFKA}"] = print_edge_kafka + query_replace["{EDGEATTRSHTTP}"] = print_edge_http + # Install query query_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), @@ -3299,15 +3322,21 @@ def _install_query(self, force: bool = False) -> str: "dataloaders", "graph_loader.gsql", ) - return install_query_file(self._graph, query_path, query_replace, force=force, distributed=self.distributed_query) + sub_query_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "gsql", + "dataloaders", + "graph_loader_sub.gsql", + ) + return install_query_files(self._graph, [sub_query_path, query_path], query_replace, force=force, distributed=[False, self.distributed_query]) def _start(self) -> None: # Create task and result queues - self._read_task_q = Queue(self.buffer_size * 2) + self._read_task_q = 
Queue(self.buffer_size) self._data_q = Queue(self.buffer_size) self._exit_event = Event() - self._start_request(True, "both") + self._start_request(True) # Start reading thread. if not self.is_hetero: @@ -3317,26 +3346,25 @@ def _start(self) -> None: v_attr_types = self._v_schema e_attr_types = self._e_schema self._reader = Thread( - target=self._read_data, - args=( - self._exit_event, - self._read_task_q, - self._data_q, - "graph", - self.output_format, - self.v_in_feats, - self.v_out_labels, - self.v_extra_feats, - v_attr_types, - self.e_in_feats, - self.e_out_labels, - self.e_extra_feats, - e_attr_types, - self.add_self_loop, - self.delimiter, - True, - self.is_hetero, - self.callback_fn + target=self._read_graph_data, + kwargs=dict( + exit_event = self._exit_event, + in_q = self._read_task_q, + out_q = self._data_q, + batch_size = self.batch_size, + out_format = self.output_format, + v_in_feats = self.v_in_feats, + v_out_labels = self.v_out_labels, + v_extra_feats = self.v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = self.e_in_feats, + e_out_labels = self.e_out_labels, + e_extra_feats = self.e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = self.add_self_loop, + delimiter = self.delimiter, + is_hetero = self.is_hetero, + callback_fn = self.callback_fn ), ) self._reader.start() From 7f5f584e3ca31b1bfbaf22868e11f4a57aa62056 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Mon, 23 Oct 2023 22:33:58 -0700 Subject: [PATCH 18/36] feat(utilities): add function to install multiple files --- pyTigerGraph/gds/utilities.py | 71 ++++++++++++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/pyTigerGraph/gds/utilities.py b/pyTigerGraph/gds/utilities.py index 5089f7ec..66d49c28 100644 --- a/pyTigerGraph/gds/utilities.py +++ b/pyTigerGraph/gds/utilities.py @@ -7,7 +7,7 @@ import re import string from os.path import join as pjoin -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Union, List from urllib.parse import urlparse if TYPE_CHECKING: @@ -166,6 +166,75 @@ def install_query_file( return query_name +def install_query_files( + conn: "TigerGraphConnection", + file_paths: List[str], + replace: dict = None, + distributed: List[bool] = [], + force: bool = False, +) -> str: + queries_to_install = [] + last_query = "" + for idx, file_path in enumerate(file_paths): + # Read the first line of the file to get query name. The first line should be + # something like CREATE QUERY query_name (... + with open(file_path) as infile: + firstline = infile.readline() + try: + query_name = re.search(r"QUERY (.+?)\(", firstline).group(1).strip() + except: + raise ValueError( + "Cannot parse the query file. It should start with CREATE QUERY ... " + ) + # If a suffix is to be added to query name + if replace and ("{QUERYSUFFIX}" in replace): + query_name = query_name.replace("{QUERYSUFFIX}", replace["{QUERYSUFFIX}"]) + last_query = query_name + # If query is already installed, skip unless force install. 
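The loop above resolves each query's name before deciding whether it needs to be (re)installed: it reads the first CREATE QUERY line, extracts the name with a regex, and substitutes the {QUERYSUFFIX} placeholder. A small self-contained illustration of that step; the first line and the suffix value here are made up (the loaders pass an md5 digest of their parameters as the suffix):

    import re

    firstline = "CREATE QUERY neighbor_loader_{QUERYSUFFIX}(SET<VERTEX> input_vertices) SYNTAX V1 {"
    query_name = re.search(r"QUERY (.+?)\(", firstline).group(1).strip()
    query_name = query_name.replace("{QUERYSUFFIX}", "d41d8cd9")
    print(query_name)  # neighbor_loader_d41d8cd9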
+ is_installed, is_enabled = is_query_installed(conn, query_name, return_status=True) + if is_installed: + if force or (not is_enabled): + query = "USE GRAPH {}\nDROP QUERY {}\n".format(conn.graphname, query_name) + resp = conn.gsql(query) + if "Successfully dropped queries" not in resp: + raise ConnectionError(resp) + else: + continue + # Otherwise, install the query from file + with open(file_path) as infile: + query = infile.read() + # Replace placeholders with actual content if given + if replace: + for placeholder in replace: + query = query.replace(placeholder, replace[placeholder]) + if distributed and distributed[idx]: + query = query.replace("CREATE QUERY", "CREATE DISTRIBUTED QUERY") + logger.debug(query) + query = ( + "USE GRAPH {}\n".format(conn.graphname) + + query + + "\n" + ) + resp = conn.gsql(query) + if "Successfully created queries" not in resp: + raise ConnectionError(resp) + queries_to_install.append(query_name) + if queries_to_install: + query = ( + "USE GRAPH {}\n".format(conn.graphname) + + "Install Query {}\n".format(",".join(queries_to_install)) + ) + print( + "Installing and optimizing queries. It might take a minute or two." + ) + resp = conn.gsql(query) + if "Query installation finished" not in resp: + raise ConnectionError(resp) + else: + print("Query installation finished.") + return last_query + + def add_attribute(conn: "TigerGraphConnection", schema_type:str, attr_type:str = None, attr_name:Union[str, dict] = None, schema_name:list = None, global_change:bool = False): ''' If the current attribute is not already added to the schema, it will create the schema job to do that. From 732023740ce97603042210e78fa8d73fafd5b221 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Mon, 23 Oct 2023 22:34:24 -0700 Subject: [PATCH 19/36] test(GraphLoaders): update unit tests --- tests/test_gds_GraphLoader.py | 350 +++++++++++++++++++++++++--------- 1 file changed, 260 insertions(+), 90 deletions(-) diff --git a/tests/test_gds_GraphLoader.py b/tests/test_gds_GraphLoader.py index b06d94aa..ef3c2c07 100644 --- a/tests/test_gds_GraphLoader.py +++ b/tests/test_gds_GraphLoader.py @@ -9,7 +9,7 @@ from pyTigerGraph.gds.utilities import is_query_installed -class TestGDSGraphLoader(unittest.TestCase): +class TestGDSGraphLoaderKafka(unittest.TestCase): @classmethod def setUpClass(cls): cls.conn = make_connection(graphname="Cora") @@ -22,15 +22,11 @@ def test_init(self): v_extra_feats=["train_mask", "val_mask", "test_mask"], batch_size=1024, shuffle=True, - filter_by=None, output_format="dataframe", - add_self_loop=False, - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 11) + self.assertIsNone(loader.num_batches) def test_iterate_pyg(self): loader = GraphLoader( @@ -40,14 +36,11 @@ def test_iterate_pyg(self): v_extra_feats=["train_mask", "val_mask", "test_mask"], batch_size=1024, shuffle=True, - filter_by=None, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygData) @@ -56,8 +49,12 @@ def test_iterate_pyg(self): self.assertIn("train_mask", data) self.assertIn("val_mask", data) self.assertIn("test_mask", data) + batch_sizes.append(data["edge_index"].shape[1]) num_batches += 1 self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + 
self.assertLessEqual(batch_sizes[-1], 1024) def test_iterate_df(self): loader = GraphLoader( @@ -67,25 +64,28 @@ def test_iterate_df(self): v_extra_feats=["train_mask", "val_mask", "test_mask"], batch_size=1024, shuffle=True, - filter_by=None, output_format="dataframe", - add_self_loop=False, - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) num_batches = 0 + batch_sizes = [] for data in loader: - # print(num_batches, data) + # print(num_batches, data, flush=True) self.assertIsInstance(data[0], DataFrame) - self.assertIsInstance(data[1], DataFrame) self.assertIn("x", data[0].columns) self.assertIn("y", data[0].columns) self.assertIn("train_mask", data[0].columns) self.assertIn("val_mask", data[0].columns) self.assertIn("test_mask", data[0].columns) + self.assertIsInstance(data[1], DataFrame) + self.assertIn("source", data[1]) + self.assertIn("target", data[1]) + batch_sizes.append(data[1].shape[0]) num_batches += 1 self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) def test_edge_attr(self): loader = GraphLoader( @@ -97,11 +97,7 @@ def test_edge_attr(self): e_extra_feats=["is_train"], batch_size=1024, shuffle=True, - filter_by=None, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) num_batches = 0 @@ -207,12 +203,9 @@ def test_init(self): shuffle=True, filter_by=None, output_format="dataframe", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 11) + self.assertIsNone(loader.num_batches) def test_iterate_pyg(self): loader = GraphLoader( @@ -222,13 +215,10 @@ def test_iterate_pyg(self): v_extra_feats=["train_mask", "val_mask", "test_mask"], batch_size=1024, shuffle=True, - filter_by=None, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygData) @@ -238,7 +228,11 @@ def test_iterate_pyg(self): self.assertIn("val_mask", data) self.assertIn("test_mask", data) num_batches += 1 + batch_sizes.append(data["edge_index"].shape[1]) self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) def test_iterate_df(self): loader = GraphLoader( @@ -248,13 +242,10 @@ def test_iterate_df(self): v_extra_feats=["train_mask", "val_mask", "test_mask"], batch_size=1024, shuffle=True, - filter_by=None, output_format="dataframe", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data[0], DataFrame) @@ -264,8 +255,14 @@ def test_iterate_df(self): self.assertIn("train_mask", data[0].columns) self.assertIn("val_mask", data[0].columns) self.assertIn("test_mask", data[0].columns) + self.assertIn("source", data[1]) + self.assertIn("target", data[1]) + batch_sizes.append(data[1].shape[0]) num_batches += 1 self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) def test_edge_attr(self): loader = GraphLoader( @@ -277,11 +274,7 @@ def test_edge_attr(self): e_extra_feats=["is_train"], batch_size=1024, shuffle=True, - filter_by=None, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4 ) num_batches = 0 for data in loader: @@ 
-327,11 +320,7 @@ def test_iterate_spektral(self): v_extra_feats=["train_mask", "val_mask", "test_mask"], batch_size=1024, shuffle=True, - filter_by=None, output_format="spektral", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) num_batches = 0 for data in loader: @@ -360,14 +349,10 @@ def test_init(self): v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, batch_size=1024, shuffle=False, - filter_by=None, output_format="dataframe", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 6) + self.assertIsNone(loader.num_batches) def test_iterate_pyg(self): loader = GraphLoader( @@ -376,26 +361,36 @@ def test_iterate_pyg(self): "v1": ["x"]}, v_out_labels={"v0": ["y"]}, v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, - batch_size=1024, + batch_size=300, shuffle=True, - filter_by=None, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygHeteroData) - self.assertIn("x", data["v0"]) - self.assertIn("y", data["v0"]) - self.assertIn("train_mask", data["v0"]) - self.assertIn("val_mask", data["v0"]) - self.assertIn("test_mask", data["v0"]) - self.assertIn("x", data["v1"]) + self.assertTrue("v0" in data.node_types or "v1" in data.node_types) + if "v0" in data.node_types: + self.assertIn("x", data["v0"]) + self.assertIn("y", data["v0"]) + self.assertIn("train_mask", data["v0"]) + self.assertIn("val_mask", data["v0"]) + self.assertIn("test_mask", data["v0"]) + if "v1" in data.node_types: + self.assertIn("x", data["v1"]) + self.assertTrue(('v0', 'v0v0', 'v0') in data.edge_types or ('v1', 'v1v1', 'v1') in data.edge_types) + batchsize = 0 + if ('v0', 'v0v0', 'v0') in data.edge_types: + batchsize += data["v0", "v0v0", "v0"].edge_index.shape[1] + if ('v1', 'v1v1', 'v1') in data.edge_types: + batchsize += data["v1", "v1v1", "v1"].edge_index.shape[1] + batch_sizes.append(batchsize) num_batches += 1 self.assertEqual(num_batches, 6) + for i in batch_sizes[:-1]: + self.assertEqual(i, 300) + self.assertLessEqual(batch_sizes[-1], 300) def test_iterate_df(self): loader = GraphLoader( @@ -404,29 +399,41 @@ def test_iterate_df(self): "v1": ["x"]}, v_out_labels={"v0": ["y"]}, v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, - batch_size=1024, + batch_size=300, shuffle=False, - filter_by=None, output_format="dataframe", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) - self.assertIsInstance(data[0]["v0"], DataFrame) - self.assertIsInstance(data[0]["v1"], DataFrame) - self.assertIsInstance(data[1]["v0v0"], DataFrame) - self.assertIsInstance(data[1]["v1v1"], DataFrame) - self.assertIn("x", data[0]["v0"].columns) - self.assertIn("y", data[0]["v0"].columns) - self.assertIn("train_mask", data[0]["v0"].columns) - self.assertIn("val_mask", data[0]["v0"].columns) - self.assertIn("test_mask", data[0]["v0"].columns) - self.assertIn("x", data[0]["v1"].columns) + self.assertTrue("v0" in data[0] or "v1" in data[0]) + if "v0" in data[0]: + self.assertIsInstance(data[0]["v0"], DataFrame) + self.assertIn("x", data[0]["v0"].columns) + self.assertIn("y", data[0]["v0"].columns) + self.assertIn("train_mask", data[0]["v0"].columns) + self.assertIn("val_mask", data[0]["v0"].columns) + self.assertIn("test_mask", data[0]["v0"].columns) + if "v1" in 
data[0]: + self.assertIsInstance(data[0]["v1"], DataFrame) + self.assertIn("x", data[0]["v1"].columns) + self.assertTrue("v0v0" in data[1] or "v1v1" in data[1]) + batchsize = 0 + if "v0v0" in data[1]: + self.assertIsInstance(data[1]["v0v0"], DataFrame) + batchsize += data[1]["v0v0"].shape[0] + self.assertEqual(data[1]["v0v0"].shape[1], 2) + if "v1v1" in data[1]: + self.assertIsInstance(data[1]["v1v1"], DataFrame) + batchsize += data[1]["v1v1"].shape[0] + self.assertEqual(data[1]["v1v1"].shape[1], 2) + batch_sizes.append(batchsize) num_batches += 1 self.assertEqual(num_batches, 6) + for i in batch_sizes[:-1]: + self.assertEqual(i, 300) + self.assertLessEqual(batch_sizes[-1], 300) def test_edge_attr(self): loader = GraphLoader( @@ -437,38 +444,197 @@ def test_edge_attr(self): v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, e_extra_feats={"v0v0": ["is_train", "is_val"], "v1v1": ["is_train", "is_val"]}, + batch_size=300, + shuffle=False, + output_format="PyG", + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + self.assertIsInstance(data, pygHeteroData) + self.assertTrue("v0" in data.node_types or "v1" in data.node_types) + if "v0" in data.node_types: + self.assertIn("x", data["v0"]) + self.assertIn("y", data["v0"]) + self.assertIn("train_mask", data["v0"]) + self.assertIn("val_mask", data["v0"]) + self.assertIn("test_mask", data["v0"]) + if "v1" in data.node_types: + self.assertIn("x", data["v1"]) + self.assertTrue(('v0', 'v0v0', 'v0') in data.edge_types or ('v1', 'v1v1', 'v1') in data.edge_types) + batchsize = 0 + if ('v0', 'v0v0', 'v0') in data.edge_types: + self.assertIn("is_train", data["v0", "v0v0", "v0"]) + self.assertIn("is_val", data["v0", "v0v0", "v0"]) + batchsize += data["v0", "v0v0", "v0"].edge_index.shape[1] + if ('v1', 'v1v1', 'v1') in data.edge_types: + self.assertIn("is_train", data["v1", "v1v1", "v1"]) + self.assertIn("is_val", data["v1", "v1v1", "v1"]) + batchsize += data["v1", "v1v1", "v1"].edge_index.shape[1] + batch_sizes.append(batchsize) + num_batches += 1 + self.assertEqual(num_batches, 6) + for i in batch_sizes[:-1]: + self.assertEqual(i, 300) + self.assertLessEqual(batch_sizes[-1], 300) + + +class TestGDSHeteroGraphLoaderKafka(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.conn = make_connection(graphname="hetero") + + def test_init(self): + loader = GraphLoader( + graph=self.conn, + v_in_feats={"v0": ["x"], + "v1": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, batch_size=1024, shuffle=False, - filter_by=None, + output_format="dataframe", + kafka_address="kafka:9092", + ) + self.assertTrue(is_query_installed(self.conn, loader.query_name)) + self.assertIsNone(loader.num_batches) + + def test_iterate_pyg(self): + loader = GraphLoader( + graph=self.conn, + v_in_feats={"v0": ["x"], + "v1": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + batch_size=300, + shuffle=True, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4 + kafka_address="kafka:9092", + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + self.assertIsInstance(data, pygHeteroData) + self.assertTrue("v0" in data.node_types or "v1" in data.node_types) + if "v0" in data.node_types: + self.assertIn("x", data["v0"]) + self.assertIn("y", data["v0"]) + self.assertIn("train_mask", data["v0"]) + self.assertIn("val_mask", data["v0"]) + self.assertIn("test_mask", 
data["v0"]) + if "v1" in data.node_types: + self.assertIn("x", data["v1"]) + self.assertTrue(('v0', 'v0v0', 'v0') in data.edge_types or ('v1', 'v1v1', 'v1') in data.edge_types) + batchsize = 0 + if ('v0', 'v0v0', 'v0') in data.edge_types: + batchsize += data["v0", "v0v0", "v0"].edge_index.shape[1] + if ('v1', 'v1v1', 'v1') in data.edge_types: + batchsize += data["v1", "v1v1", "v1"].edge_index.shape[1] + batch_sizes.append(batchsize) + num_batches += 1 + self.assertEqual(num_batches, 6) + for i in batch_sizes[:-1]: + self.assertEqual(i, 300) + self.assertLessEqual(batch_sizes[-1], 300) + + def test_iterate_df(self): + loader = GraphLoader( + graph=self.conn, + v_in_feats={"v0": ["x"], + "v1": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + batch_size=300, + shuffle=False, + output_format="dataframe", + kafka_address="kafka:9092", + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + self.assertTrue("v0" in data[0] or "v1" in data[0]) + if "v0" in data[0]: + self.assertIsInstance(data[0]["v0"], DataFrame) + self.assertIn("x", data[0]["v0"].columns) + self.assertIn("y", data[0]["v0"].columns) + self.assertIn("train_mask", data[0]["v0"].columns) + self.assertIn("val_mask", data[0]["v0"].columns) + self.assertIn("test_mask", data[0]["v0"].columns) + if "v1" in data[0]: + self.assertIsInstance(data[0]["v1"], DataFrame) + self.assertIn("x", data[0]["v1"].columns) + self.assertTrue("v0v0" in data[1] or "v1v1" in data[1]) + batchsize = 0 + if "v0v0" in data[1]: + self.assertIsInstance(data[1]["v0v0"], DataFrame) + batchsize += data[1]["v0v0"].shape[0] + self.assertEqual(data[1]["v0v0"].shape[1], 2) + if "v1v1" in data[1]: + self.assertIsInstance(data[1]["v1v1"], DataFrame) + batchsize += data[1]["v1v1"].shape[0] + self.assertEqual(data[1]["v1v1"].shape[1], 2) + batch_sizes.append(batchsize) + num_batches += 1 + self.assertEqual(num_batches, 6) + for i in batch_sizes[:-1]: + self.assertEqual(i, 300) + self.assertLessEqual(batch_sizes[-1], 300) + + def test_edge_attr(self): + loader = GraphLoader( + graph=self.conn, + v_in_feats={"v0": ["x"], + "v1": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + e_extra_feats={"v0v0": ["is_train", "is_val"], + "v1v1": ["is_train", "is_val"]}, + batch_size=300, + shuffle=False, + output_format="PyG", + kafka_address="kafka:9092", ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygHeteroData) - self.assertIn("x", data["v0"]) - self.assertIn("y", data["v0"]) - self.assertIn("train_mask", data["v0"]) - self.assertIn("val_mask", data["v0"]) - self.assertIn("test_mask", data["v0"]) - self.assertIn("x", data["v1"]) - self.assertIn("is_train", data["v0v0"]) - self.assertIn("is_train", data["v1v1"]) - self.assertIn("is_val", data["v0v0"]) - self.assertIn("is_val", data["v1v1"]) + self.assertTrue("v0" in data.node_types or "v1" in data.node_types) + if "v0" in data.node_types: + self.assertIn("x", data["v0"]) + self.assertIn("y", data["v0"]) + self.assertIn("train_mask", data["v0"]) + self.assertIn("val_mask", data["v0"]) + self.assertIn("test_mask", data["v0"]) + if "v1" in data.node_types: + self.assertIn("x", data["v1"]) + self.assertTrue(('v0', 'v0v0', 'v0') in data.edge_types or ('v1', 'v1v1', 'v1') in data.edge_types) + batchsize = 0 + if ('v0', 'v0v0', 'v0') in data.edge_types: + self.assertIn("is_train", data["v0", "v0v0", "v0"]) + 
self.assertIn("is_val", data["v0", "v0v0", "v0"]) + batchsize += data["v0", "v0v0", "v0"].edge_index.shape[1] + if ('v1', 'v1v1', 'v1') in data.edge_types: + self.assertIn("is_train", data["v1", "v1v1", "v1"]) + self.assertIn("is_val", data["v1", "v1v1", "v1"]) + batchsize += data["v1", "v1v1", "v1"].edge_index.shape[1] + batch_sizes.append(batchsize) num_batches += 1 - self.assertEqual(num_batches, 2) + self.assertEqual(num_batches, 6) + for i in batch_sizes[:-1]: + self.assertEqual(i, 300) + self.assertLessEqual(batch_sizes[-1], 300) if __name__ == "__main__": suite = unittest.TestSuite() - suite.addTest(TestGDSGraphLoader("test_init")) - suite.addTest(TestGDSGraphLoader("test_iterate_pyg")) - suite.addTest(TestGDSGraphLoader("test_iterate_df")) - suite.addTest(TestGDSGraphLoader("test_edge_attr")) + suite.addTest(TestGDSGraphLoaderKafka("test_init")) + suite.addTest(TestGDSGraphLoaderKafka("test_iterate_pyg")) + suite.addTest(TestGDSGraphLoaderKafka("test_iterate_df")) + suite.addTest(TestGDSGraphLoaderKafka("test_edge_attr")) # suite.addTest(TestGDSGraphLoader("test_sasl_plaintext")) # suite.addTest(TestGDSGraphLoader("test_sasl_ssl")) suite.addTest(TestGDSGraphLoaderREST("test_init")) @@ -480,6 +646,10 @@ def test_edge_attr(self): suite.addTest(TestGDSHeteroGraphLoaderREST("test_iterate_pyg")) suite.addTest(TestGDSHeteroGraphLoaderREST("test_iterate_df")) suite.addTest(TestGDSHeteroGraphLoaderREST("test_edge_attr")) + suite.addTest(TestGDSHeteroGraphLoaderKafka("test_init")) + suite.addTest(TestGDSHeteroGraphLoaderKafka("test_iterate_pyg")) + suite.addTest(TestGDSHeteroGraphLoaderKafka("test_iterate_df")) + suite.addTest(TestGDSHeteroGraphLoaderKafka("test_edge_attr")) runner = unittest.TextTestRunner(verbosity=2, failfast=True) runner.run(suite) From ffe5b0daad97df3de5c0a9dd31fe8c56ff9f1bbf Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Thu, 26 Oct 2023 13:21:00 -0700 Subject: [PATCH 20/36] feat(dataloaders): update NeighborLoader --- pyTigerGraph/gds/dataloaders.py | 334 ++++++--- .../gds/gsql/dataloaders/graph_loader.gsql | 2 +- .../gds/gsql/dataloaders/neighbor_loader.gsql | 254 ++++--- .../gsql/dataloaders/neighbor_loader_sub.gsql | 38 + tests/test_gds_NeighborLoader.py | 670 +++++++++++------- 5 files changed, 837 insertions(+), 461 deletions(-) create mode 100644 pyTigerGraph/gds/gsql/dataloaders/neighbor_loader_sub.gsql diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index 44ada0aa..476742fa 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -228,7 +228,7 @@ def __init__( else: self._kafka_topic_base = self.loader_id + "_topic" self.num_batches = num_batches - self.output_format = output_format + self.output_format = output_format.lower() self.buffer_size = buffer_size self.timeout = timeout self._iterations = 0 @@ -800,7 +800,7 @@ def _read_graph_data( is_hetero = is_hetero, scipy = scipy, spektral = spektral - ) + ) else: raise NotImplementedError if callback_fn: @@ -1222,6 +1222,19 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, # Reformat as a graph. # Need to have a pair of tables for edges and vertices. 
vertices, edges = raw + # Dedupe vertices if there is is_seed column as the same vertex might be seed and non-seed at the same time + if not is_hetero: + if "is_seed" in vertices.columns: + seeds = set(vertices[vertices.is_seed.astype(int).astype(bool)]["vid"]) + vertices.loc[vertices.vid.isin(seeds), "is_seed"] = True + vertices.drop_duplicates(subset="vid", inplace=True, ignore_index=True) + else: + for vtype in vertices: + df = vertices[vtype] + if "is_seed" in df.columns: + seeds = set(df[df.is_seed.astype(int).astype(bool)]["vid"]) + df.loc[df.vid.isin(seeds), "is_seed"] = True + df.drop_duplicates(subset="vid", inplace=True, ignore_index=True) edgelist = BaseLoader._get_edgelist(vertices, edges, is_hetero, e_attr_types) if not is_hetero: # Deal with edgelist first @@ -1423,6 +1436,19 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, # Reformat as a graph. # Need to have a pair of tables for edges and vertices. vertices, edges = raw + # Dedupe vertices if there is is_seed column as the same vertex might be seed and non-seed at the same time + if not is_hetero: + if "is_seed" in vertices.columns: + seeds = set(vertices[vertices.is_seed.astype(int).astype(bool)]["vid"]) + vertices.loc[vertices.vid.isin(seeds), "is_seed"] = True + vertices.drop_duplicates(subset="vid", inplace=True, ignore_index=True) + else: + for vtype in vertices: + df = vertices[vtype] + if "is_seed" in df.columns: + seeds = set(df[df.is_seed.astype(int).astype(bool)]["vid"]) + df.loc[df.vid.isin(seeds), "is_seed"] = True + df.drop_duplicates(subset="vid", inplace=True, ignore_index=True) edgelist = BaseLoader._get_edgelist(vertices, edges, is_hetero, e_attr_types) if not is_hetero: # Deal with edgelist first @@ -1569,6 +1595,19 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, # Reformat as a graph. # Need to have a pair of tables for edges and vertices. 
vertices, edges = raw + # Dedupe vertices if there is is_seed column as the same vertex might be seed and non-seed at the same time + if not is_hetero: + if "is_seed" in vertices.columns: + seeds = set(vertices[vertices.is_seed.astype(int).astype(bool)]["vid"]) + vertices.loc[vertices.vid.isin(seeds), "is_seed"] = True + vertices.drop_duplicates(subset="vid", inplace=True, ignore_index=True) + else: + for vtype in vertices: + df = vertices[vtype] + if "is_seed" in df.columns: + seeds = set(df[df.is_seed.astype(int).astype(bool)]["vid"]) + df.loc[df.vid.isin(seeds), "is_seed"] = True + df.drop_duplicates(subset="vid", inplace=True, ignore_index=True) edgelist = BaseLoader._get_edgelist(vertices, edges, is_hetero, e_attr_types) if not is_hetero: # Deal with edgelist first @@ -2030,6 +2069,7 @@ def __init__( self._etypes = list(self._e_schema.keys()) self._vtypes = sorted(self._vtypes) self._etypes = sorted(self._etypes) + # Resolve seeds if v_seed_types: if isinstance(v_seed_types, list): self._seed_types = v_seed_types @@ -2041,10 +2081,15 @@ def __init__( self._seed_types = list(filter_by.keys()) else: self._seed_types = self._vtypes + if set(self._seed_types) - set(self._vtypes): + raise ValueError("Seed type has to be one of the vertex types to retrieve") - # Resolve seeds if batch_size: - # If batch_size is given, calculate the number of batches + # batch size takes precedence over number of batches + self.batch_size = batch_size + self.num_batches = None + else: + # If number of batches is given, calculate batch size if not filter_by: num_vertices = sum(self._graph.getVertexCount(self._seed_types).values()) elif isinstance(filter_by, str): @@ -2059,17 +2104,11 @@ def __init__( ) else: raise ValueError("filter_by should be None, attribute name, or dict of {type name: attribute name}.") - self.num_batches = math.ceil(num_vertices / batch_size) - else: - # Otherwise, take the number of batches as is. 
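As this hunk spells out, an explicit batch_size now takes precedence and num_batches is left as None at construction time, which is what the updated tests assert; only when no batch_size is given is it derived from the seed-vertex count. A compact restatement of the rule with hypothetical numbers (resolve_batching is not a real helper in the codebase):

    import math

    def resolve_batching(batch_size, num_batches, num_seed_vertices):
        # batch_size wins; num_batches is only a fallback for deriving it.
        if batch_size:
            return batch_size, None
        return math.ceil(num_seed_vertices / num_batches), num_batches

    print(resolve_batching(16, None, 140))  # (16, None)
    print(resolve_batching(None, 9, 140))   # (16, 9)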
+ self.batch_size = math.ceil(num_vertices / num_batches) self.num_batches = num_batches # Initialize parameters for the query - if batch_size: - self._payload["batch_size"] = batch_size - self._payload["num_batches"] = self.num_batches self._payload["num_neighbors"] = num_neighbors self._payload["num_hops"] = num_hops - self._payload["num_heap_inserts"] = self.num_heap_inserts if filter_by: if isinstance(filter_by, str): self._payload["filter_by"] = filter_by @@ -2107,6 +2146,23 @@ def _install_query(self, force: bool = False): if isinstance(self.v_in_feats, dict) or isinstance(self.e_in_feats, dict): # Multiple vertex types print_query_seed = "" + for idx, vtype in enumerate(self._seed_types): + v_attr_names = ( + self.v_in_feats.get(vtype, []) + + self.v_out_labels.get(vtype, []) + + self.v_extra_feats.get(vtype, []) + ) + v_attr_types = self._v_schema[vtype] + print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) + print_query_seed += """ + {} s.type == "{}" THEN + @@v_batch += (s.type + delimiter + stringify(getvid(s)) {} + delimiter + "1\\n")"""\ + .format("IF" if idx==0 else "ELSE IF", + vtype, + "+ delimiter + " + print_attr if v_attr_names else "") + print_query_seed += """ + END""" + query_replace["{SEEDVERTEXATTRS}"] = print_query_seed print_query_other = "" for idx, vtype in enumerate(self._vtypes): v_attr_names = ( @@ -2115,20 +2171,15 @@ def _install_query(self, force: bool = False): + self.v_extra_feats.get(vtype, []) ) v_attr_types = self._v_schema[vtype] - if v_attr_names: - print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query_seed += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + delimiter + "1\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype, print_attr) - print_query_other += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + delimiter + "0\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype, print_attr) - else: - print_query_seed += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + "1\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype) - print_query_other += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + "0\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype) - print_query_seed += "END" - print_query_other += "END" - query_replace["{SEEDVERTEXATTRS}"] = print_query_seed + print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) + print_query_other += """ + {} s.type == "{}" THEN + @@v_batch += (s.type + delimiter + stringify(getvid(s)) {} + delimiter + "0\\n")"""\ + .format("IF" if idx==0 else "ELSE IF", + vtype, + "+ delimiter + " + print_attr if v_attr_names else "") + print_query_other += """ + END""" query_replace["{OTHERVERTEXATTRS}"] = print_query_other # Multiple edge types print_query = "" @@ -2139,44 +2190,36 @@ def _install_query(self, force: bool = False): + self.e_extra_feats.get(etype, []) ) e_attr_types = self._e_schema[etype] - if e_attr_names: - print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) - print_query += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", etype, print_attr) - else: - print_query += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + 
stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", etype) - print_query += "END" + print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) + print_query += """ + {} e.type == "{}" THEN + @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) {} + "\\n")"""\ + .format("IF" if idx==0 else "ELSE IF", + etype, + "+ delimiter + " + print_attr if e_attr_names else "") + print_query += """ + END""" query_replace["{EDGEATTRS}"] = print_query else: # Ignore vertex types v_attr_names = self.v_in_feats + self.v_out_labels + self.v_extra_feats v_attr_types = next(iter(self._v_schema.values())) - if v_attr_names: - print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query = '@@v_batch += (stringify(getvid(s)) + delimiter + {} + delimiter + "1\\n")'.format( - print_attr - ) - query_replace["{SEEDVERTEXATTRS}"] = print_query - print_query = '@@v_batch += (stringify(getvid(s)) + delimiter + {} + delimiter + "0\\n")'.format( - print_attr - ) - query_replace["{OTHERVERTEXATTRS}"] = print_query - else: - print_query = '@@v_batch += (stringify(getvid(s)) + delimiter + "1\\n")' - query_replace["{SEEDVERTEXATTRS}"] = print_query - print_query = '@@v_batch += (stringify(getvid(s)) + delimiter + "0\\n")' - query_replace["{OTHERVERTEXATTRS}"] = print_query + print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) + print_query = '@@v_batch += (stringify(getvid(s)) {} + delimiter + "1\\n")'.format( + "+ delimiter + " + print_attr if v_attr_names else "" + ) + query_replace["{SEEDVERTEXATTRS}"] = print_query + print_query = '@@v_batch += (stringify(getvid(s)) {} + delimiter + "0\\n")'.format( + "+ delimiter + " + print_attr if v_attr_names else "" + ) + query_replace["{OTHERVERTEXATTRS}"] = print_query # Ignore edge types e_attr_names = self.e_in_feats + self.e_out_labels + self.e_extra_feats e_attr_types = next(iter(self._e_schema.values())) - if e_attr_names: - print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n")'.format( - print_attr - ) - else: - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n")' + print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) + print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) {} + "\\n")'.format( + " + delimiter + " + print_attr if e_attr_names else "" + ) query_replace["{EDGEATTRS}"] = print_query # Install query query_path = os.path.join( @@ -2185,15 +2228,21 @@ def _install_query(self, force: bool = False): "dataloaders", "neighbor_loader.gsql", ) - return install_query_file(self._graph, query_path, query_replace, force=force, distributed=self.distributed_query) + sub_query_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "gsql", + "dataloaders", + "neighbor_loader_sub.gsql", + ) + return install_query_files(self._graph, [sub_query_path, query_path], query_replace, force=force, distributed=[False, self.distributed_query]) def _start(self) -> None: # Create task and result queues - self._read_task_q = Queue(self.buffer_size * 2) + self._read_task_q = Queue(self.buffer_size) self._data_q = Queue(self.buffer_size) self._exit_event = Event() - self._start_request(True, "both") + self._start_request(True) # Start reading thread. 
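The reader thread launched below drains _read_task_q and regroups the incoming (vertex, edge) chunks into batches of batch_size, with a trailing None as the end-of-stream marker, the same pattern as the _read_graph_data loop earlier in this series. A stripped-down, self-contained version of that pooling loop on toy data:

    from queue import Queue

    in_q, out_q, batch_size = Queue(), Queue(), 2
    for chunk in [("v1\n", "e1\n"), ("v2\n", "e2\n"), ("v3\n", "e3\n"), None]:
        in_q.put(chunk)  # (vertex_text, edge_text) pairs; None marks the end

    vertex_buffer, edge_buffer, buffered = set(), set(), 0
    while True:
        raw = in_q.get()
        last_batch = raw is None and buffered > 0
        if raw is not None:
            vertex_buffer.update(raw[0].splitlines())
            edge_buffer.update(raw[1].splitlines())
            buffered += 1
        if buffered >= batch_size or last_batch:
            out_q.put((sorted(vertex_buffer), sorted(edge_buffer)))  # one batch to parse
            vertex_buffer, edge_buffer, buffered = set(), set(), 0
        if raw is None:
            out_q.put(None)  # pass the sentinel downstream
            break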
if not self.is_hetero: @@ -2210,26 +2259,25 @@ def _start(self) -> None: v_attr_types[vtype]["is_seed"] = "bool" e_attr_types = self._e_schema self._reader = Thread( - target=self._read_data, - args=( - self._exit_event, - self._read_task_q, - self._data_q, - "graph", - self.output_format, - self.v_in_feats, - self.v_out_labels, - v_extra_feats, - v_attr_types, - self.e_in_feats, - self.e_out_labels, - self.e_extra_feats, - e_attr_types, - self.add_self_loop, - self.delimiter, - True, - self.is_hetero, - self.callback_fn + target=self._read_graph_data, + kwargs=dict( + exit_event = self._exit_event, + in_q = self._read_task_q, + out_q = self._data_q, + batch_size = self.batch_size, + out_format = self.output_format, + v_in_feats = self.v_in_feats, + v_out_labels = self.v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = self.e_in_feats, + e_out_labels = self.e_out_labels, + e_extra_feats = self.e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = self.add_self_loop, + delimiter = self.delimiter, + is_hetero = self.is_hetero, + callback_fn = self.callback_fn ), ) self._reader.start() @@ -2266,7 +2314,6 @@ def fetch(self, vertices: list) -> None: _payload = {} _payload["v_types"] = self._payload["v_types"] _payload["e_types"] = self._payload["e_types"] - _payload["num_batches"] = 1 _payload["num_neighbors"] = self._payload["num_neighbors"] _payload["num_hops"] = self._payload["num_hops"] _payload["delimiter"] = self._payload["delimiter"] @@ -2292,11 +2339,15 @@ def fetch(self, vertices: list) -> None: v_attr_types[vtype]["is_seed"] = "bool" v_attr_types[vtype]["primary_id"] = "str" e_attr_types = self._e_schema - i = resp[0] - data = self._parse_data( - raw = (i["vertex_batch"], i["edge_batch"]), - in_format = "graph", - out_format = self.output_format, + vertex_batch = set() + edge_batch = set() + for i in resp: + if "pids" in i: + break + vertex_batch.update(i["vertex_batch"].splitlines()) + edge_batch.update(i["edge_batch"].splitlines()) + data = self._parse_graph_data_to_df( + raw = (vertex_batch, edge_batch), v_in_feats = self.v_in_feats, v_out_labels = self.v_out_labels, v_extra_feats = v_extra_feats, @@ -2305,14 +2356,119 @@ def fetch(self, vertices: list) -> None: e_out_labels = self.e_out_labels, e_extra_feats = self.e_extra_feats, e_attr_types = e_attr_types, - add_self_loop = self.add_self_loop, delimiter = self.delimiter, - reindex = True, primary_id = i["pids"], is_hetero = self.is_hetero, - callback_fn = self.callback_fn ) - # Return data + if self.output_format == "dataframe" or self.output_format== "df": + vertices, edges = data + if not self.is_hetero: + for column in vertices.columns: + vertices[column] = pd.to_numeric(vertices[column], errors="ignore") + for column in edges.columns: + edges[column] = pd.to_numeric(edges[column], errors="ignore") + else: + for key in vertices: + for column in vertices[key].columns: + vertices[key][column] = pd.to_numeric(vertices[key][column], errors="ignore") + for key in edges: + for column in edges[key].columns: + edges[key][column] = pd.to_numeric(edges[key][column], errors="ignore") + data = (vertices, edges) + elif self.output_format == "pyg": + try: + import torch + except ImportError: + raise ImportError( + "PyTorch is not installed. Please install it to use PyG or DGL output." + ) + try: + import torch_geometric as pyg + except ImportError: + raise ImportError( + "PyG is not installed. Please install PyG to use PyG format." 
+ ) + data = BaseLoader._parse_df_to_pyg( + raw = data, + v_in_feats = self.v_in_feats, + v_out_labels = self.v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = self.e_in_feats, + e_out_labels = self.e_out_labels, + e_extra_feats = self.e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = self.add_self_loop, + is_hetero = self.is_hetero, + torch = torch, + pyg = pyg + ) + elif self.output_format == "dgl": + try: + import torch + except ImportError: + raise ImportError( + "PyTorch is not installed. Please install it to use PyG or DGL output." + ) + try: + import dgl + except ImportError: + raise ImportError( + "DGL is not installed. Please install DGL to use DGL format." + ) + data = BaseLoader._parse_df_to_dgl( + raw = data, + v_in_feats = self.v_in_feats, + v_out_labels = self.v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = self.e_in_feats, + e_out_labels = self.e_out_labels, + e_extra_feats = self.e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = self.add_self_loop, + is_hetero = self.is_hetero, + torch = torch, + dgl= dgl + ) + elif self.output_format == "spektral" and self.is_hetero==False: + try: + import tensorflow as tf + except ImportError: + raise ImportError( + "Tensorflow is not installed. Please install it to use spektral output." + ) + try: + import scipy + except ImportError: + raise ImportError( + "scipy is not installed. Please install it to use spektral output." + ) + try: + import spektral + except ImportError: + raise ImportError( + "Spektral is not installed. Please install it to use spektral output." + ) + data = BaseLoader._parse_df_to_spektral( + raw = data, + v_in_feats = self.v_in_feats, + v_out_labels = self.v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = self.e_in_feats, + e_out_labels = self.e_out_labels, + e_extra_feats = self.e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = self.add_self_loop, + is_hetero = self.is_hetero, + scipy = scipy, + spektral = spektral + ) + else: + raise NotImplementedError + if self.callback_fn: + data = self.callback_fn(data) return data @@ -3196,7 +3352,7 @@ def __init__( self.batch_size = batch_size self.num_batches = None else: - # If number of batches is given, calculate batch size + # If number of batches is given, calculate batch size if filter_by: num_edges = 0 for e_type in self._etypes: diff --git a/pyTigerGraph/gds/gsql/dataloaders/graph_loader.gsql b/pyTigerGraph/gds/gsql/dataloaders/graph_loader.gsql index 5e0d5f61..fca2a83e 100644 --- a/pyTigerGraph/gds/gsql/dataloaders/graph_loader.gsql +++ b/pyTigerGraph/gds/gsql/dataloaders/graph_loader.gsql @@ -6,7 +6,7 @@ CREATE QUERY graph_loader_{QUERYSUFFIX}( BOOL shuffle=FALSE, INT num_chunks=2, STRING kafka_address="", - STRING kafka_topic, + STRING kafka_topic="", INT kafka_topic_partitions=1, STRING kafka_max_size="104857600", INT kafka_timeout=300000, diff --git a/pyTigerGraph/gds/gsql/dataloaders/neighbor_loader.gsql b/pyTigerGraph/gds/gsql/dataloaders/neighbor_loader.gsql index 8be04db6..d1cd3add 100644 --- a/pyTigerGraph/gds/gsql/dataloaders/neighbor_loader.gsql +++ b/pyTigerGraph/gds/gsql/dataloaders/neighbor_loader.gsql @@ -1,7 +1,5 @@ CREATE QUERY neighbor_loader_{QUERYSUFFIX}( SET input_vertices, - INT batch_size, - INT num_batches=1, INT num_neighbors=10, INT num_hops=2, BOOL shuffle=FALSE, @@ -10,8 +8,9 @@ CREATE QUERY neighbor_loader_{QUERYSUFFIX}( SET e_types, SET seed_types, STRING delimiter, 
+ INT num_chunks=2, STRING kafka_address="", - STRING kafka_topic, + STRING kafka_topic="", INT kafka_topic_partitions=1, STRING kafka_max_size="104857600", INT kafka_timeout=300000, @@ -26,8 +25,7 @@ CREATE QUERY neighbor_loader_{QUERYSUFFIX}( STRING ssl_endpoint_identification_algorithm="", STRING sasl_kerberos_service_name="", STRING sasl_kerberos_keytab="", - STRING sasl_kerberos_principal="", - INT num_heap_inserts = 10 + STRING sasl_kerberos_principal="" ) SYNTAX V1 { /* This query generates the neighborhood subgraphs of given seed vertices (i.e., `input_vertices`). @@ -55,148 +53,138 @@ CREATE QUERY neighbor_loader_{QUERYSUFFIX}( sasl_password : SASL password for Kafka. ssl_ca_location: Path to CA certificate for verifying the Kafka broker key. */ - TYPEDEF TUPLE ID_Tuple; - INT num_vertices; - INT kafka_errcode; SumAccum @tmp_id; - SumAccum @@kafka_error; - UINT producer; - INT batch_s; - OrAccum @prev_sampled; - # Initialize Kafka producer - IF kafka_address != "" THEN - producer = init_kafka_producer( - kafka_address, kafka_max_size, security_protocol, - sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, - ssl_certificate_location, ssl_key_location, ssl_key_password, - ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, - sasl_kerberos_keytab, sasl_kerberos_principal); - END; - - # Shuffle vertex ID if needed + # If getting all vertices of given types IF input_vertices.size()==0 THEN start = {seed_types}; - IF filter_by IS NOT NULL THEN - start = SELECT s FROM start:s WHERE s.getAttr(filter_by, "BOOL"); - END; + # Filter seeds if needed + seeds = SELECT s + FROM start:s + WHERE filter_by is NULL OR s.getAttr(filter_by, "BOOL"); + # Shuffle vertex ID if needed IF shuffle THEN - num_vertices = start.size(); + INT num_vertices = seeds.size(); res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = floor(rand()*num_vertices); + FROM seeds:s + POST-ACCUM s.@tmp_id = floor(rand()*num_vertices) + LIMIT 1; ELSE res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = getvid(s); - END; - END; - IF batch_size IS NULL THEN - batch_s = ceil(res.size()/num_batches); - ELSE - batch_s = batch_size; - END; - # Generate subgraphs - FOREACH batch_id IN RANGE[0, num_batches-1] DO - SumAccum @@v_batch; - SumAccum @@e_batch; - SetAccum @@printed_vertices; - SetAccum @@printed_edges; - SetAccum @@seeds; - # Get seeds - IF input_vertices.size()==0 THEN - start = {seed_types}; - HeapAccum (1, tmp_id ASC) @@batch_heap; - @@batch_heap.resize(batch_s); - IF filter_by IS NOT NULL THEN - FOREACH iter IN RANGE[0,num_heap_inserts-1] DO - _verts = SELECT s FROM start:s - WHERE s.@tmp_id % num_heap_inserts == iter AND NOT s.@prev_sampled AND s.getAttr(filter_by, "BOOL") - POST-ACCUM @@batch_heap += ID_Tuple(s.@tmp_id, s); - END; - FOREACH elem IN @@batch_heap DO - @@seeds += elem.v; - END; - seeds = {@@seeds}; - seeds = SELECT s - FROM seeds:s - POST-ACCUM - s.@prev_sampled += TRUE, - {SEEDVERTEXATTRS}, - @@printed_vertices += s; - ELSE - FOREACH iter IN RANGE[0,num_heap_inserts-1] DO - _verts = SELECT s FROM start:s - WHERE s.@tmp_id % num_heap_inserts == iter AND NOT s.@prev_sampled - POST-ACCUM @@batch_heap += ID_Tuple(s.@tmp_id, s); - END; - FOREACH elem IN @@batch_heap DO - @@seeds += elem.v; - END; - seeds = {@@seeds}; - seeds = SELECT s - FROM start:s - POST-ACCUM - s.@prev_sampled += TRUE, - {SEEDVERTEXATTRS}, - @@printed_vertices += s; - END; - ELSE - start = input_vertices; - seeds = SELECT s - FROM start:s - POST-ACCUM - {SEEDVERTEXATTRS}, - @@printed_vertices += s; - END; - # 
Get neighbors of seeeds - FOREACH i IN RANGE[1, num_hops] DO - seeds = SELECT t - FROM seeds:s -(e_types:e)- v_types:t - SAMPLE num_neighbors EDGE WHEN s.outdegree() >= 1 - ACCUM - IF NOT @@printed_edges.contains(e) THEN - {EDGEATTRS}, - @@printed_edges += e - END; - attr = SELECT s - FROM seeds:s - POST-ACCUM - IF NOT @@printed_vertices.contains(s) THEN - {OTHERVERTEXATTRS}, - @@printed_vertices += s - END; + FROM seeds:s + POST-ACCUM s.@tmp_id = getvid(s) + LIMIT 1; END; + + # Export data + # If using kafka to export IF kafka_address != "" THEN - # Write to kafka - kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "vertex_batch_" + stringify(batch_id), @@v_batch); - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending vertex batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n"); + SumAccum @@kafka_error; + + # Initialize Kafka producer + UINT producer = init_kafka_producer( + kafka_address, kafka_max_size, security_protocol, + sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, + ssl_certificate_location, ssl_key_location, ssl_key_password, + ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, + sasl_kerberos_keytab, sasl_kerberos_principal); + + FOREACH chunk IN RANGE[0, num_chunks-1] DO + res = SELECT s + FROM seeds:s + WHERE s.@tmp_id % num_chunks == chunk + POST-ACCUM + LIST msg = neighbor_loader_sub_{QUERYSUFFIX}(s, delimiter, num_hops, num_neighbors, e_types, v_types), + BOOL is_first=True, + FOREACH i in msg DO + IF is_first THEN + INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "vertex_batch_" + stringify(getvid(s)), i), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending vertex batch for " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\\n") + END, + is_first = False + ELSE + INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "edge_batch_" + stringify(getvid(s)), i), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending edge batch for " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\\n") + END + END + END + LIMIT 1; + END; + + FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO + INT kafka_errcode = write_to_kafka(producer, kafka_topic, i, "STOP", ""); + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending STOP signal to topic partition " + stringify(i) + ": " + stringify(kafka_errcode) + "\n"); + END; END; - kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "edge_batch_" + stringify(batch_id), @@e_batch); + + INT kafka_errcode = close_kafka_producer(producer, kafka_timeout); IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending edge batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n"); + @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); END; + PRINT @@kafka_error as kafkaError; + # Else return as http response ELSE - # Add to response - IF input_vertices.size()==0 THEN - PRINT @@v_batch AS vertex_batch, @@e_batch AS edge_batch; - ELSE - MapAccum @@id_map; - source = @@printed_vertices; - res = - SELECT s - FROM source:s - POST-ACCUM @@id_map += (getvid(s) -> s); - PRINT @@v_batch AS vertex_batch, @@e_batch AS edge_batch, @@id_map AS pids; + FOREACH chunk IN RANGE[0, num_chunks-1] DO + MapAccum @@v_batch; + MapAccum @@e_batch; + + res = SELECT s + FROM seeds:s + WHERE s.@tmp_id % num_chunks == chunk + POST-ACCUM + LIST msg = 
neighbor_loader_sub_{QUERYSUFFIX}(s, delimiter, num_hops, num_neighbors, e_types, v_types), + BOOL is_first=True, + FOREACH i in msg DO + IF is_first THEN + @@v_batch += (getvid(s) -> i), + is_first = False + ELSE + @@e_batch += (getvid(s) -> i) + END + END + LIMIT 1; + + FOREACH (k,v) IN @@v_batch DO + PRINT v as vertex_batch, @@e_batch.get(k) as edge_batch; + END; END; - END; - END; - IF kafka_address != "" THEN - kafka_errcode = close_kafka_producer(producer, kafka_timeout); - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); END; - PRINT @@kafka_error as kafkaError; + # Else get given vertices. + ELSE + MapAccum @@v_batch; + MapAccum @@e_batch; + MapAccum @@id_map; + + seeds = input_vertices; + res = SELECT s + FROM seeds:s + POST-ACCUM + LIST msg = neighbor_loader_sub_{QUERYSUFFIX}(s, delimiter, num_hops, num_neighbors, e_types, v_types), + BOOL is_first=True, + FOREACH i in msg DO + IF is_first THEN + @@v_batch += (getvid(s) -> i), + is_first = False + ELSE + @@e_batch += (getvid(s) -> i) + END + END, + @@id_map += (getvid(s) -> s) + LIMIT 1; + + FOREACH (k,v) IN @@v_batch DO + PRINT v as vertex_batch, @@e_batch.get(k) as edge_batch; + END; + + FOREACH hop IN RANGE[1, num_hops] DO + seeds = SELECT t + FROM seeds:s -(e_types:e)- v_types:t + POST-ACCUM + @@id_map += (getvid(t) -> t); + END; + PRINT @@id_map AS pids; END; -} \ No newline at end of file +} diff --git a/pyTigerGraph/gds/gsql/dataloaders/neighbor_loader_sub.gsql b/pyTigerGraph/gds/gsql/dataloaders/neighbor_loader_sub.gsql new file mode 100644 index 00000000..c944aca5 --- /dev/null +++ b/pyTigerGraph/gds/gsql/dataloaders/neighbor_loader_sub.gsql @@ -0,0 +1,38 @@ +CREATE QUERY neighbor_loader_sub_{QUERYSUFFIX} (VERTEX v, STRING delimiter, INT num_hops, INT num_neighbors, SET e_types, SET v_types) +RETURNS (ListAccum) +SYNTAX V1 +{ + SumAccum @@v_batch; + SumAccum @@e_batch; + SetAccum @@printed_vertices; + SetAccum @@printed_edges; + ListAccum @@ret; + + start = {v}; + res = SELECT s + FROM start:s + POST-ACCUM + @@printed_vertices += s, + {SEEDVERTEXATTRS}; + + FOREACH i IN RANGE[1, num_hops] DO + start = SELECT t + FROM start:s -(e_types:e)- v_types:t + SAMPLE num_neighbors EDGE WHEN s.outdegree() >= 1 + ACCUM + IF NOT @@printed_edges.contains(e) THEN + @@printed_edges += e, + {EDGEATTRS} + END; + start = SELECT s + FROM start:s + POST-ACCUM + IF NOT @@printed_vertices.contains(s) THEN + @@printed_vertices += s, + {OTHERVERTEXATTRS} + END; + END; + @@ret += @@v_batch; + @@ret += @@e_batch; + RETURN @@ret; +} diff --git a/tests/test_gds_NeighborLoader.py b/tests/test_gds_NeighborLoader.py index 2bb9a7c0..fa509a91 100644 --- a/tests/test_gds_NeighborLoader.py +++ b/tests/test_gds_NeighborLoader.py @@ -26,13 +26,10 @@ def test_init(self): shuffle=True, filter_by="train_mask", output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, - kafka_address="kafka:9092", + kafka_address="kafka:9092" ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 9) + self.assertIsNone(loader.num_batches) def test_iterate_pyg(self): loader = NeighborLoader( @@ -46,27 +43,27 @@ def test_iterate_pyg(self): shuffle=True, filter_by="train_mask", output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, - kafka_address="kafka:9092", + kafka_address="kafka:9092" ) - for epoch in range(2): - with self.subTest(i=epoch): - num_batches = 0 - for data in loader: - # print(num_batches, data) - 
self.assertIsInstance(data, pygData) - self.assertIn("x", data) - self.assertIn("y", data) - self.assertIn("train_mask", data) - self.assertIn("val_mask", data) - self.assertIn("test_mask", data) - self.assertIn("is_seed", data) - self.assertGreater(data["x"].shape[0], 0) - self.assertGreater(data["edge_index"].shape[1], 0) - num_batches += 1 - self.assertEqual(num_batches, 9) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + self.assertIsInstance(data, pygData) + self.assertIn("x", data) + self.assertIn("y", data) + self.assertIn("train_mask", data) + self.assertIn("val_mask", data) + self.assertIn("test_mask", data) + self.assertIn("is_seed", data) + self.assertGreater(data["x"].shape[0], 0) + self.assertGreater(data["edge_index"].shape[1], 0) + num_batches += 1 + batch_sizes.append(int(data["is_seed"].sum())) + self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_iterate_stop_pyg(self): loader = NeighborLoader( @@ -80,9 +77,6 @@ def test_iterate_stop_pyg(self): shuffle=True, filter_by="train_mask", output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) for epoch in range(2): @@ -109,33 +103,6 @@ def test_iterate_stop_pyg(self): rq_id = self.conn.getRunningQueries()["results"] self.assertEqual(len(rq_id), 0) - def test_whole_graph_pyg(self): - loader = NeighborLoader( - graph=self.conn, - v_in_feats=["x"], - v_out_labels=["y"], - v_extra_feats=["train_mask", "val_mask", "test_mask"], - num_batches=1, - num_neighbors=10, - num_hops=2, - shuffle=False, - filter_by="train_mask", - output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, - kafka_address="kafka:9092", - ) - data = loader.data - # print(data) - self.assertIsInstance(data, pygData) - self.assertIn("x", data) - self.assertIn("y", data) - self.assertIn("train_mask", data) - self.assertIn("val_mask", data) - self.assertIn("test_mask", data) - self.assertIn("is_seed", data) - def test_edge_attr(self): loader = NeighborLoader( graph=self.conn, @@ -150,14 +117,12 @@ def test_edge_attr(self): shuffle=True, filter_by="train_mask", output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, - kafka_address="kafka:9092", + kafka_address="kafka:9092" ) for epoch in range(2): with self.subTest(i=epoch): num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygData) @@ -170,7 +135,11 @@ def test_edge_attr(self): self.assertIn("edge_feat", data) self.assertIn("is_train", data) num_batches += 1 + batch_sizes.append(int(data["is_seed"].sum())) self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_sasl_plaintext(self): loader = NeighborLoader( @@ -307,12 +276,9 @@ def test_init(self): shuffle=True, filter_by="train_mask", output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 9) + self.assertIsNone(loader.num_batches) def test_iterate_pyg(self): loader = NeighborLoader( @@ -326,11 +292,9 @@ def test_iterate_pyg(self): shuffle=True, filter_by="train_mask", output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygData) @@ 
-343,33 +307,11 @@ def test_iterate_pyg(self): self.assertGreater(data["x"].shape[0], 0) self.assertGreater(data["edge_index"].shape[1], 0) num_batches += 1 + batch_sizes.append(int(data["is_seed"].sum())) self.assertEqual(num_batches, 9) - - def test_whole_graph_pyg(self): - loader = NeighborLoader( - graph=self.conn, - v_in_feats=["x"], - v_out_labels=["y"], - v_extra_feats=["train_mask", "val_mask", "test_mask"], - num_batches=1, - num_neighbors=10, - num_hops=2, - shuffle=False, - filter_by="train_mask", - output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, - ) - data = loader.data - # print(data) - self.assertIsInstance(data, pygData) - self.assertIn("x", data) - self.assertIn("y", data) - self.assertIn("train_mask", data) - self.assertIn("val_mask", data) - self.assertIn("test_mask", data) - self.assertIn("is_seed", data) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_edge_attr(self): loader = NeighborLoader( @@ -385,13 +327,11 @@ def test_edge_attr(self): shuffle=True, filter_by="train_mask", output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) for epoch in range(2): with self.subTest(i=epoch): num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygData) @@ -404,7 +344,11 @@ def test_edge_attr(self): self.assertIn("edge_feat", data) self.assertIn("is_train", data) num_batches += 1 + batch_sizes.append(int(data["is_seed"].sum())) self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_fetch(self): loader = NeighborLoader( @@ -415,26 +359,26 @@ def test_fetch(self): batch_size=16, num_neighbors=10, num_hops=2, - shuffle=True, + shuffle=False, filter_by="train_mask", output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) data = loader.fetch( [ - {"primary_id": "100", "type": "Paper"}, + {"primary_id": "60", "type": "Paper"}, {"primary_id": "55", "type": "Paper"}, ] ) + # print(data) + # print(data["primary_id"]) + # print(data["is_seed"]) self.assertIn("primary_id", data) self.assertGreater(data["x"].shape[0], 2) self.assertGreater(data["edge_index"].shape[1], 0) - self.assertIn("100", data["primary_id"]) + self.assertIn("60", data["primary_id"]) self.assertIn("55", data["primary_id"]) for i, d in enumerate(data["primary_id"]): - if d == "100" or d == "55": + if d == "60" or d == "55": self.assertTrue(data["is_seed"][i].item()) else: self.assertFalse(data["is_seed"][i].item()) @@ -452,23 +396,20 @@ def test_fetch_delimiter(self): delimiter="$|", filter_by="train_mask", output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) data = loader.fetch( [ - {"primary_id": "100", "type": "Paper"}, + {"primary_id": "60", "type": "Paper"}, {"primary_id": "55", "type": "Paper"}, ] ) self.assertIn("primary_id", data) self.assertGreater(data["x"].shape[0], 2) self.assertGreater(data["edge_index"].shape[1], 0) - self.assertIn("100", data["primary_id"]) + self.assertIn("60", data["primary_id"]) self.assertIn("55", data["primary_id"]) for i, d in enumerate(data["primary_id"]): - if d == "100" or d == "55": + if d == "60" or d == "55": self.assertTrue(data["is_seed"][i].item()) else: self.assertFalse(data["is_seed"][i].item()) @@ -485,9 +426,6 @@ def test_iterate_spektral(self): shuffle=True, filter_by="train_mask", output_format="spektral", - add_self_loop=False, - loader_id=None, - buffer_size=4 ) 
num_batches = 0 for data in loader: @@ -502,32 +440,6 @@ def test_iterate_spektral(self): num_batches += 1 self.assertEqual(num_batches, 9) - def test_whole_graph_spektral(self): - loader = NeighborLoader( - graph=self.conn, - v_in_feats=["x"], - v_out_labels=["y"], - v_extra_feats=["train_mask", "val_mask", "test_mask"], - num_batches=1, - num_neighbors=10, - num_hops=2, - shuffle=False, - filter_by="train_mask", - output_format="spektral", - add_self_loop=False, - loader_id=None, - buffer_size=4, - ) - data = loader.data - # print(data) - # self.assertIsInstance(data, spData) - self.assertIn("x", data) - self.assertIn("y", data) - self.assertIn("train_mask", data) - self.assertIn("val_mask", data) - self.assertIn("test_mask", data) - self.assertIn("is_seed", data) - def test_reinstall_query(self): loader = NeighborLoader( graph=self.conn, @@ -540,9 +452,6 @@ def test_reinstall_query(self): shuffle=True, filter_by="train_mask", output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) query_name = loader.query_name @@ -565,12 +474,9 @@ def test_init(self): num_hops=2, shuffle=True, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 18) + self.assertIsNone(loader.num_batches) def test_whole_graph_df(self): loader = NeighborLoader( @@ -583,14 +489,11 @@ def test_whole_graph_df(self): num_hops=2, shuffle=False, output_format="dataframe", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) data = loader.data - self.assertTupleEqual(data[0]["v0"].shape, (76, 7)) - self.assertTupleEqual(data[0]["v1"].shape, (110, 3)) - self.assertTupleEqual(data[0]["v2"].shape, (100, 3)) + self.assertTupleEqual(data[0]["v0"].shape, (152, 7)) + self.assertTupleEqual(data[0]["v1"].shape, (220, 3)) + self.assertTupleEqual(data[0]["v2"].shape, (200, 3)) self.assertTrue( data[1]["v0v0"].shape[0] > 0 and data[1]["v0v0"].shape[0] <= 710 ) @@ -621,9 +524,6 @@ def test_whole_graph_pyg(self): num_hops=2, shuffle=False, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) data = loader.data # print(data) @@ -668,57 +568,75 @@ def test_iterate_pyg(self): v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]}, v_out_labels={"v0": ["y"]}, v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + v_seed_types=["v2"], batch_size=16, num_neighbors=10, num_hops=2, shuffle=False, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygHeteroData) - self.assertGreater(data["v0"]["x"].shape[0], 0) - self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["y"].shape[0]) - self.assertEqual( - data["v0"]["x"].shape[0], data["v0"]["train_mask"].shape[0] - ) - self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["test_mask"].shape[0]) - self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["is_seed"].shape[0]) - self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["val_mask"].shape[0]) - self.assertGreater(data["v1"]["x"].shape[0], 0) - self.assertEqual(data["v1"]["x"].shape[0], data["v1"]["is_seed"].shape[0]) self.assertGreater(data["v2"]["x"].shape[0], 0) self.assertEqual(data["v2"]["x"].shape[0], data["v2"]["is_seed"].shape[0]) + batch_sizes.append(int(data["v2"]["is_seed"].sum())) + if "v1" in data.node_types: + 
self.assertGreater(data["v1"]["x"].shape[0], 0) + self.assertEqual(data["v1"]["x"].shape[0], data["v1"]["is_seed"].shape[0]) + if "v2" in data.node_types: + self.assertGreater(data["v0"]["x"].shape[0], 0) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["y"].shape[0]) + self.assertEqual( + data["v0"]["x"].shape[0], data["v0"]["train_mask"].shape[0] + ) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["test_mask"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["is_seed"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["val_mask"].shape[0]) self.assertTrue( - data["v0v0"]["edge_index"].shape[1] > 0 - and data["v0v0"]["edge_index"].shape[1] <= 710 - ) - self.assertTrue( - data["v1v1"]["edge_index"].shape[1] > 0 - and data["v1v1"]["edge_index"].shape[1] <= 1044 - ) - self.assertTrue( - data["v1v2"]["edge_index"].shape[1] > 0 - and data["v1v2"]["edge_index"].shape[1] <= 1038 - ) - self.assertTrue( - data["v2v0"]["edge_index"].shape[1] > 0 - and data["v2v0"]["edge_index"].shape[1] <= 943 - ) - self.assertTrue( - data["v2v1"]["edge_index"].shape[1] > 0 - and data["v2v1"]["edge_index"].shape[1] <= 959 - ) - self.assertTrue( - data["v2v2"]["edge_index"].shape[1] > 0 - and data["v2v2"]["edge_index"].shape[1] <= 966 - ) + ('v0', 'v0v0', 'v0') in data.edge_types or + ('v1', 'v1v1', 'v1') in data.edge_types or + ('v1', 'v1v2', 'v2') in data.edge_types or + ('v2', 'v2v0', 'v0') in data.edge_types or + ('v2', 'v2v1', 'v1') in data.edge_types or + ('v2', 'v2v2', 'v2') in data.edge_types) + if ('v0', 'v0v0', 'v0') in data.edge_types: + self.assertTrue( + data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0 + and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710 + ) + if ('v1', 'v1v1', 'v1') in data.edge_types: + self.assertTrue( + data['v1', 'v1v1', 'v1']["edge_index"].shape[1] > 0 + and data['v1', 'v1v1', 'v1']["edge_index"].shape[1] <= 1044 + ) + if ('v1', 'v1v2', 'v2') in data.edge_types: + self.assertTrue( + data['v1', 'v1v2', 'v2']["edge_index"].shape[1] > 0 + and data['v1', 'v1v2', 'v2']["edge_index"].shape[1] <= 1038 + ) + if ('v2', 'v2v0', 'v0') in data.edge_types: + self.assertTrue( + data['v2', 'v2v0', 'v0']["edge_index"].shape[1] > 0 + and data['v2', 'v2v0', 'v0']["edge_index"].shape[1] <= 943 + ) + if ('v2', 'v2v1', 'v1') in data.edge_types: + self.assertTrue( + data['v2', 'v2v1', 'v1']["edge_index"].shape[1] > 0 + and data['v2', 'v2v1', 'v1']["edge_index"].shape[1] <= 959 + ) + if ('v2', 'v2v2', 'v2') in data.edge_types: + self.assertTrue( + data['v2', 'v2v2', 'v2']["edge_index"].shape[1] > 0 + and data['v2', 'v2v2', 'v2']["edge_index"].shape[1] <= 966 + ) num_batches += 1 - self.assertEqual(num_batches, 18) + self.assertEqual(num_batches, 7) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_iterate_pyg_multichar_delimiter(self): loader = NeighborLoader( @@ -726,58 +644,76 @@ def test_iterate_pyg_multichar_delimiter(self): v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]}, v_out_labels={"v0": ["y"]}, v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + v_seed_types=["v2"], batch_size=16, num_neighbors=10, num_hops=2, shuffle=False, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, delimiter="|$" ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygHeteroData) - self.assertGreater(data["v0"]["x"].shape[0], 0) - self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["y"].shape[0]) - 
self.assertEqual( - data["v0"]["x"].shape[0], data["v0"]["train_mask"].shape[0] - ) - self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["test_mask"].shape[0]) - self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["is_seed"].shape[0]) - self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["val_mask"].shape[0]) - self.assertGreater(data["v1"]["x"].shape[0], 0) - self.assertEqual(data["v1"]["x"].shape[0], data["v1"]["is_seed"].shape[0]) self.assertGreater(data["v2"]["x"].shape[0], 0) self.assertEqual(data["v2"]["x"].shape[0], data["v2"]["is_seed"].shape[0]) + batch_sizes.append(int(data["v2"]["is_seed"].sum())) + if "v1" in data.node_types: + self.assertGreater(data["v1"]["x"].shape[0], 0) + self.assertEqual(data["v1"]["x"].shape[0], data["v1"]["is_seed"].shape[0]) + if "v2" in data.node_types: + self.assertGreater(data["v0"]["x"].shape[0], 0) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["y"].shape[0]) + self.assertEqual( + data["v0"]["x"].shape[0], data["v0"]["train_mask"].shape[0] + ) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["test_mask"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["is_seed"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["val_mask"].shape[0]) self.assertTrue( - data["v0v0"]["edge_index"].shape[1] > 0 - and data["v0v0"]["edge_index"].shape[1] <= 710 - ) - self.assertTrue( - data["v1v1"]["edge_index"].shape[1] > 0 - and data["v1v1"]["edge_index"].shape[1] <= 1044 - ) - self.assertTrue( - data["v1v2"]["edge_index"].shape[1] > 0 - and data["v1v2"]["edge_index"].shape[1] <= 1038 - ) - self.assertTrue( - data["v2v0"]["edge_index"].shape[1] > 0 - and data["v2v0"]["edge_index"].shape[1] <= 943 - ) - self.assertTrue( - data["v2v1"]["edge_index"].shape[1] > 0 - and data["v2v1"]["edge_index"].shape[1] <= 959 - ) - self.assertTrue( - data["v2v2"]["edge_index"].shape[1] > 0 - and data["v2v2"]["edge_index"].shape[1] <= 966 - ) + ('v0', 'v0v0', 'v0') in data.edge_types or + ('v1', 'v1v1', 'v1') in data.edge_types or + ('v1', 'v1v2', 'v2') in data.edge_types or + ('v2', 'v2v0', 'v0') in data.edge_types or + ('v2', 'v2v1', 'v1') in data.edge_types or + ('v2', 'v2v2', 'v2') in data.edge_types) + if ('v0', 'v0v0', 'v0') in data.edge_types: + self.assertTrue( + data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0 + and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710 + ) + if ('v1', 'v1v1', 'v1') in data.edge_types: + self.assertTrue( + data['v1', 'v1v1', 'v1']["edge_index"].shape[1] > 0 + and data['v1', 'v1v1', 'v1']["edge_index"].shape[1] <= 1044 + ) + if ('v1', 'v1v2', 'v2') in data.edge_types: + self.assertTrue( + data['v1', 'v1v2', 'v2']["edge_index"].shape[1] > 0 + and data['v1', 'v1v2', 'v2']["edge_index"].shape[1] <= 1038 + ) + if ('v2', 'v2v0', 'v0') in data.edge_types: + self.assertTrue( + data['v2', 'v2v0', 'v0']["edge_index"].shape[1] > 0 + and data['v2', 'v2v0', 'v0']["edge_index"].shape[1] <= 943 + ) + if ('v2', 'v2v1', 'v1') in data.edge_types: + self.assertTrue( + data['v2', 'v2v1', 'v1']["edge_index"].shape[1] > 0 + and data['v2', 'v2v1', 'v1']["edge_index"].shape[1] <= 959 + ) + if ('v2', 'v2v2', 'v2') in data.edge_types: + self.assertTrue( + data['v2', 'v2v2', 'v2']["edge_index"].shape[1] > 0 + and data['v2', 'v2v2', 'v2']["edge_index"].shape[1] <= 966 + ) num_batches += 1 - self.assertEqual(num_batches, 18) + self.assertEqual(num_batches, 7) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_fetch(self): loader = NeighborLoader( @@ -790,9 
+726,6 @@ def test_fetch(self): num_hops=2, shuffle=False, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) data = loader.fetch( [{"primary_id": "10", "type": "v0"}, {"primary_id": "55", "type": "v0"}] @@ -820,9 +753,6 @@ def test_fetch_delimiter(self): shuffle=False, output_format="PyG", delimiter="$|", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) data = loader.fetch( [{"primary_id": "10", "type": "v0"}, {"primary_id": "55", "type": "v0"}] @@ -849,9 +779,6 @@ def test_metadata(self): num_hops=2, shuffle=False, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) test = (["v0", "v1", "v2"], @@ -865,19 +792,281 @@ def test_metadata(self): metadata = loader.metadata() self.assertEqual(test, metadata) + +class TestGDSHeteroNeighborLoaderKafka(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.conn = make_connection(graphname="hetero") + + def test_init(self): + loader = NeighborLoader( + graph=self.conn, + v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + batch_size=16, + num_neighbors=10, + num_hops=2, + shuffle=True, + output_format="PyG", + kafka_address="kafka:9092" + ) + self.assertTrue(is_query_installed(self.conn, loader.query_name)) + self.assertIsNone(loader.num_batches) + + def test_whole_graph_df(self): + loader = NeighborLoader( + graph=self.conn, + v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + num_batches=1, + num_neighbors=10, + num_hops=2, + shuffle=False, + output_format="dataframe", + kafka_address="kafka:9092" + ) + data = loader.data + self.assertTupleEqual(data[0]["v0"].shape, (152, 7)) + self.assertTupleEqual(data[0]["v1"].shape, (220, 3)) + self.assertTupleEqual(data[0]["v2"].shape, (200, 3)) + self.assertTrue( + data[1]["v0v0"].shape[0] > 0 and data[1]["v0v0"].shape[0] <= 710 + ) + self.assertTrue( + data[1]["v1v1"].shape[0] > 0 and data[1]["v1v1"].shape[0] <= 1044 + ) + self.assertTrue( + data[1]["v1v2"].shape[0] > 0 and data[1]["v1v2"].shape[0] <= 1038 + ) + self.assertTrue( + data[1]["v2v0"].shape[0] > 0 and data[1]["v2v0"].shape[0] <= 943 + ) + self.assertTrue( + data[1]["v2v1"].shape[0] > 0 and data[1]["v2v1"].shape[0] <= 959 + ) + self.assertTrue( + data[1]["v2v2"].shape[0] > 0 and data[1]["v2v2"].shape[0] <= 966 + ) + + def test_whole_graph_pyg(self): + loader = NeighborLoader( + graph=self.conn, + v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + num_batches=1, + num_neighbors=10, + num_hops=2, + shuffle=False, + output_format="PyG", + kafka_address="kafka:9092" + ) + data = loader.data + # print(data) + self.assertTupleEqual(data["v0"]["x"].shape, (76, 77)) + self.assertEqual(data["v0"]["y"].shape[0], 76) + self.assertEqual(data["v0"]["train_mask"].shape[0], 76) + self.assertEqual(data["v0"]["test_mask"].shape[0], 76) + self.assertEqual(data["v0"]["val_mask"].shape[0], 76) + self.assertEqual(data["v0"]["is_seed"].shape[0], 76) + self.assertTupleEqual(data["v1"]["x"].shape, (110, 57)) + self.assertEqual(data["v1"]["is_seed"].shape[0], 110) + self.assertTupleEqual(data["v2"]["x"].shape, (100, 48)) + self.assertEqual(data["v2"]["is_seed"].shape[0], 100) + self.assertTrue( + data["v0v0"]["edge_index"].shape[1] > 0 + and data["v0v0"]["edge_index"].shape[1] <= 710 + ) + 
self.assertTrue( + data["v1v1"]["edge_index"].shape[1] > 0 + and data["v1v1"]["edge_index"].shape[1] <= 1044 + ) + self.assertTrue( + data["v1v2"]["edge_index"].shape[1] > 0 + and data["v1v2"]["edge_index"].shape[1] <= 1038 + ) + self.assertTrue( + data["v2v0"]["edge_index"].shape[1] > 0 + and data["v2v0"]["edge_index"].shape[1] <= 943 + ) + self.assertTrue( + data["v2v1"]["edge_index"].shape[1] > 0 + and data["v2v1"]["edge_index"].shape[1] <= 959 + ) + self.assertTrue( + data["v2v2"]["edge_index"].shape[1] > 0 + and data["v2v2"]["edge_index"].shape[1] <= 966 + ) + + def test_iterate_pyg(self): + loader = NeighborLoader( + graph=self.conn, + v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + v_seed_types=["v2"], + batch_size=16, + num_neighbors=10, + num_hops=2, + shuffle=False, + output_format="PyG", + kafka_address="kafka:9092" + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + self.assertIsInstance(data, pygHeteroData) + self.assertGreater(data["v2"]["x"].shape[0], 0) + self.assertEqual(data["v2"]["x"].shape[0], data["v2"]["is_seed"].shape[0]) + batch_sizes.append(int(data["v2"]["is_seed"].sum())) + if "v1" in data.node_types: + self.assertGreater(data["v1"]["x"].shape[0], 0) + self.assertEqual(data["v1"]["x"].shape[0], data["v1"]["is_seed"].shape[0]) + if "v2" in data.node_types: + self.assertGreater(data["v0"]["x"].shape[0], 0) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["y"].shape[0]) + self.assertEqual( + data["v0"]["x"].shape[0], data["v0"]["train_mask"].shape[0] + ) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["test_mask"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["is_seed"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["val_mask"].shape[0]) + self.assertTrue( + ('v0', 'v0v0', 'v0') in data.edge_types or + ('v1', 'v1v1', 'v1') in data.edge_types or + ('v1', 'v1v2', 'v2') in data.edge_types or + ('v2', 'v2v0', 'v0') in data.edge_types or + ('v2', 'v2v1', 'v1') in data.edge_types or + ('v2', 'v2v2', 'v2') in data.edge_types) + if ('v0', 'v0v0', 'v0') in data.edge_types: + self.assertTrue( + data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0 + and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710 + ) + if ('v1', 'v1v1', 'v1') in data.edge_types: + self.assertTrue( + data['v1', 'v1v1', 'v1']["edge_index"].shape[1] > 0 + and data['v1', 'v1v1', 'v1']["edge_index"].shape[1] <= 1044 + ) + if ('v1', 'v1v2', 'v2') in data.edge_types: + self.assertTrue( + data['v1', 'v1v2', 'v2']["edge_index"].shape[1] > 0 + and data['v1', 'v1v2', 'v2']["edge_index"].shape[1] <= 1038 + ) + if ('v2', 'v2v0', 'v0') in data.edge_types: + self.assertTrue( + data['v2', 'v2v0', 'v0']["edge_index"].shape[1] > 0 + and data['v2', 'v2v0', 'v0']["edge_index"].shape[1] <= 943 + ) + if ('v2', 'v2v1', 'v1') in data.edge_types: + self.assertTrue( + data['v2', 'v2v1', 'v1']["edge_index"].shape[1] > 0 + and data['v2', 'v2v1', 'v1']["edge_index"].shape[1] <= 959 + ) + if ('v2', 'v2v2', 'v2') in data.edge_types: + self.assertTrue( + data['v2', 'v2v2', 'v2']["edge_index"].shape[1] > 0 + and data['v2', 'v2v2', 'v2']["edge_index"].shape[1] <= 966 + ) + num_batches += 1 + self.assertEqual(num_batches, 7) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) + + def test_iterate_pyg_multichar_delimiter(self): + loader = NeighborLoader( + graph=self.conn, + v_in_feats={"v0": 
["x"], "v1": ["x"], "v2": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + v_seed_types=["v2"], + batch_size=16, + num_neighbors=10, + num_hops=2, + shuffle=False, + output_format="PyG", + delimiter="|$", + kafka_address="kafka:9092" + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + self.assertIsInstance(data, pygHeteroData) + self.assertGreater(data["v2"]["x"].shape[0], 0) + self.assertEqual(data["v2"]["x"].shape[0], data["v2"]["is_seed"].shape[0]) + batch_sizes.append(int(data["v2"]["is_seed"].sum())) + if "v1" in data.node_types: + self.assertGreater(data["v1"]["x"].shape[0], 0) + self.assertEqual(data["v1"]["x"].shape[0], data["v1"]["is_seed"].shape[0]) + if "v2" in data.node_types: + self.assertGreater(data["v0"]["x"].shape[0], 0) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["y"].shape[0]) + self.assertEqual( + data["v0"]["x"].shape[0], data["v0"]["train_mask"].shape[0] + ) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["test_mask"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["is_seed"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["val_mask"].shape[0]) + self.assertTrue( + ('v0', 'v0v0', 'v0') in data.edge_types or + ('v1', 'v1v1', 'v1') in data.edge_types or + ('v1', 'v1v2', 'v2') in data.edge_types or + ('v2', 'v2v0', 'v0') in data.edge_types or + ('v2', 'v2v1', 'v1') in data.edge_types or + ('v2', 'v2v2', 'v2') in data.edge_types) + if ('v0', 'v0v0', 'v0') in data.edge_types: + self.assertTrue( + data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0 + and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710 + ) + if ('v1', 'v1v1', 'v1') in data.edge_types: + self.assertTrue( + data['v1', 'v1v1', 'v1']["edge_index"].shape[1] > 0 + and data['v1', 'v1v1', 'v1']["edge_index"].shape[1] <= 1044 + ) + if ('v1', 'v1v2', 'v2') in data.edge_types: + self.assertTrue( + data['v1', 'v1v2', 'v2']["edge_index"].shape[1] > 0 + and data['v1', 'v1v2', 'v2']["edge_index"].shape[1] <= 1038 + ) + if ('v2', 'v2v0', 'v0') in data.edge_types: + self.assertTrue( + data['v2', 'v2v0', 'v0']["edge_index"].shape[1] > 0 + and data['v2', 'v2v0', 'v0']["edge_index"].shape[1] <= 943 + ) + if ('v2', 'v2v1', 'v1') in data.edge_types: + self.assertTrue( + data['v2', 'v2v1', 'v1']["edge_index"].shape[1] > 0 + and data['v2', 'v2v1', 'v1']["edge_index"].shape[1] <= 959 + ) + if ('v2', 'v2v2', 'v2') in data.edge_types: + self.assertTrue( + data['v2', 'v2v2', 'v2']["edge_index"].shape[1] > 0 + and data['v2', 'v2v2', 'v2']["edge_index"].shape[1] <= 966 + ) + num_batches += 1 + self.assertEqual(num_batches, 7) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) + + if __name__ == "__main__": suite = unittest.TestSuite() suite.addTest(TestGDSNeighborLoaderKafka("test_init")) suite.addTest(TestGDSNeighborLoaderKafka("test_iterate_pyg")) suite.addTest(TestGDSNeighborLoaderKafka("test_iterate_stop_pyg")) - suite.addTest(TestGDSNeighborLoaderKafka("test_whole_graph_pyg")) suite.addTest(TestGDSNeighborLoaderKafka("test_edge_attr")) suite.addTest(TestGDSNeighborLoaderKafka("test_distributed_loaders")) # suite.addTest(TestGDSNeighborLoaderKafka("test_sasl_plaintext")) # suite.addTest(TestGDSNeighborLoaderKafka("test_sasl_ssl")) suite.addTest(TestGDSNeighborLoaderREST("test_init")) suite.addTest(TestGDSNeighborLoaderREST("test_iterate_pyg")) - suite.addTest(TestGDSNeighborLoaderREST("test_whole_graph_pyg")) 
suite.addTest(TestGDSNeighborLoaderREST("test_edge_attr")) suite.addTest(TestGDSNeighborLoaderREST("test_fetch")) suite.addTest(TestGDSNeighborLoaderREST("test_fetch_delimiter")) @@ -890,6 +1079,11 @@ def test_metadata(self): suite.addTest(TestGDSHeteroNeighborLoaderREST("test_fetch")) suite.addTest(TestGDSHeteroNeighborLoaderREST("test_fetch_delimiter")) suite.addTest(TestGDSHeteroNeighborLoaderREST("test_metadata")) + suite.addTest(TestGDSHeteroNeighborLoaderKafka("test_init")) + suite.addTest(TestGDSHeteroNeighborLoaderKafka("test_whole_graph_df")) + suite.addTest(TestGDSHeteroNeighborLoaderKafka("test_whole_graph_pyg")) + suite.addTest(TestGDSHeteroNeighborLoaderKafka("test_iterate_pyg")) + suite.addTest(TestGDSHeteroNeighborLoaderKafka("test_iterate_pyg_multichar_delimiter")) runner = unittest.TextTestRunner(verbosity=2, failfast=True) runner.run(suite) From a78f33d9c90f2eb394647af85bd856f140506ca0 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Tue, 31 Oct 2023 14:18:53 -0700 Subject: [PATCH 21/36] feat: update all about EdgeNeighborLoader --- pyTigerGraph/gds/dataloaders.py | 199 ++++++------ .../gds/gsql/dataloaders/edge_nei_loader.gsql | 288 +++++------------- .../gsql/dataloaders/edge_nei_loader_sub.gsql | 46 +++ tests/test_gds_EdgeNeighborLoader.py | 184 ++++++++++- 4 files changed, 410 insertions(+), 307 deletions(-) create mode 100644 pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader_sub.gsql diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index 476742fa..3d898829 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -3372,7 +3372,13 @@ def __init__( self.batch_size = math.ceil(num_edges / num_batches) self.num_batches = num_batches if filter_by: - self._payload["filter_by"] = filter_by + if isinstance(filter_by, str): + self._payload["filter_by"] = filter_by + else: + attr = set(filter_by.values()) + if len(attr) != 1: + raise NotImplementedError("Filtering by different attributes for different edge types is not supported. 
Please use the same attribute for different types.") + self._payload["filter_by"] = attr.pop() self._payload["shuffle"] = shuffle self._payload["v_types"] = self._vtypes self._payload["e_types"] = self._etypes @@ -3700,29 +3706,45 @@ def __init__( self._vtypes = sorted(self._vtypes) self._etypes = sorted(self._etypes) # Resolve seeds - self._seed_types = self._etypes if ((not filter_by) or isinstance(filter_by, str)) else list(filter_by.keys()) - if not(filter_by) and e_seed_types: - if isinstance(e_seed_types, str): - self._seed_types = [e_seed_types] - elif isinstance(e_seed_types, list): + if e_seed_types: + if isinstance(e_seed_types, list): self._seed_types = e_seed_types + elif isinstance(e_seed_types, str): + self._seed_types = [e_seed_types] else: - raise TigerGraphException("e_seed_types must be type list or string.") + raise TigerGraphException("e_seed_types must be either of type list or string.") + elif isinstance(filter_by, dict): + self._seed_types = list(filter_by.keys()) + else: + self._seed_types = self._etypes + if set(self._seed_types) - set(self._etypes): + raise ValueError("Seed type has to be one of the edge types to retrieve") # Resolve number of batches if batch_size: - # If batch_size is given, calculate the number of batches + # batch size takes precedence over number of batches + self.batch_size = batch_size + self.num_batches = None + else: + # If number of batches is given, calculate batch size if filter_by: - num_edges = sum(self._graph.getEdgeStats(e_type)[e_type][filter_by if isinstance(filter_by, str) else filter_by[e_type]]["TRUE"] for e_type in self._seed_types) + num_edges = 0 + for e_type in self._seed_types: + tmp = self._graph.getEdgeStats(e_type)[e_type][filter_by if isinstance(filter_by, str) else filter_by[e_type]]["TRUE"] + if self._e_schema[e_type]["IsDirected"]: + num_edges += tmp + else: + num_edges += 2*tmp else: - num_edges = sum(self._graph.getEdgeCount(i) for i in self._seed_types) - self.num_batches = math.ceil(num_edges / batch_size) - else: - # Otherwise, take the number of batches as is. 
- self.num_batches = num_batches + num_edges = 0 + for e_type in self._seed_types: + tmp = self._graph.getEdgeCount(e_type) + if self._e_schema[e_type]["IsDirected"]: + num_edges += tmp + else: + num_edges += 2*tmp + self.batch_size = math.ceil(num_edges / num_batches) + self.num_batches = num_batches # Initialize parameters for the query - if batch_size: - self._payload["batch_size"] = batch_size - self._payload["num_batches"] = self.num_batches self._payload["num_neighbors"] = num_neighbors self._payload["num_hops"] = num_hops self._payload["delimiter"] = delimiter @@ -3768,17 +3790,33 @@ def _install_query(self, force: bool = False): + self.v_extra_feats.get(vtype, []) ) v_attr_types = self._v_schema[vtype] - if v_attr_names: - print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype, print_attr) - else: - print_query += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype) - print_query += "END" + print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) + print_query += """ + {} s.type == "{}" THEN + @@v_batch += (s.type + delimiter + stringify(getvid(s)) {}+ "\\n")"""\ + .format("IF" if idx==0 else "ELSE IF", vtype, + "+ delimiter + " + print_attr if v_attr_names else "") + print_query += """ + END""" query_replace["{VERTEXATTRS}"] = print_query # Multiple edge types print_query_seed = "" + for idx, etype in enumerate(self._seed_types): + e_attr_names = ( + self.e_in_feats.get(etype, []) + + self.e_out_labels.get(etype, []) + + self.e_extra_feats.get(etype, []) + ) + e_attr_types = self._e_schema[etype] + print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) + print_query_seed += """ + {} e.type == "{}" THEN + @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ delimiter + "1\\n")"""\ + .format("IF" if idx==0 else "ELSE IF", etype, + "+ delimiter + " + print_attr if e_attr_names else "") + print_query_seed += """ + END""" + query_replace["{SEEDEDGEATTRS}"] = print_query_seed print_query_other = "" for idx, etype in enumerate(self._etypes): e_attr_names = ( @@ -3787,52 +3825,36 @@ def _install_query(self, force: bool = False): + self.e_extra_feats.get(etype, []) ) e_attr_types = self._e_schema[etype] - if e_attr_names: - print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) - print_query_seed += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + delimiter + "1\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", etype, print_attr) - print_query_other += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + delimiter + "0\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", etype, print_attr) - else: - print_query_seed += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + "1\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", etype) - print_query_other += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + "0\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", etype) - 
print_query_seed += "END" - print_query_other += "END" - query_replace["{SEEDEDGEATTRS}"] = print_query_seed + print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) + print_query_other += """ + {} e.type == "{}" THEN + @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ delimiter + "0\\n")"""\ + .format("IF" if idx==0 else "ELSE IF", etype, + "+ delimiter + "+ print_attr if e_attr_names else "") + print_query_other += """ + END""" query_replace["{OTHEREDGEATTRS}"] = print_query_other else: # Ignore vertex types v_attr_names = self.v_in_feats + self.v_out_labels + self.v_extra_feats v_attr_types = next(iter(self._v_schema.values())) - if v_attr_names: - print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query = '@@v_batch += (stringify(getvid(s)) + delimiter + {} + "\\n")'.format( - print_attr - ) - query_replace["{VERTEXATTRS}"] = print_query - else: - print_query = '@@v_batch += (stringify(getvid(s)) + "\\n")' - query_replace["{VERTEXATTRS}"] = print_query + print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) + print_query = '@@v_batch += (stringify(getvid(s)) {}+ "\\n")'.format( + "+ delimiter + " + print_attr if v_attr_names else "" + ) + query_replace["{VERTEXATTRS}"] = print_query # Ignore edge types e_attr_names = self.e_in_feats + self.e_out_labels + self.e_extra_feats e_attr_types = next(iter(self._e_schema.values())) - if e_attr_names: - print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + delimiter + "1\\n")'.format( - print_attr - ) - query_replace["{SEEDEDGEATTRS}"] = print_query - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + delimiter + "0\\n")'.format( - print_attr - ) - query_replace["{OTHEREDGEATTRS}"] = print_query - else: - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + "1\\n")' - query_replace["{SEEDEDGEATTRS}"] = print_query - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + "0\\n")' - query_replace["{OTHEREDGEATTRS}"] = print_query + print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) + print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ delimiter + "1\\n")'.format( + "+ delimiter + " + print_attr if e_attr_names else "" + ) + query_replace["{SEEDEDGEATTRS}"] = print_query + print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ delimiter + "0\\n")'.format( + "+ delimiter + " + print_attr if e_attr_names else "" + ) + query_replace["{OTHEREDGEATTRS}"] = print_query # Install query query_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), @@ -3840,15 +3862,21 @@ def _install_query(self, force: bool = False): "dataloaders", "edge_nei_loader.gsql", ) - return install_query_file(self._graph, query_path, query_replace, force=force, distributed=self.distributed_query) + sub_query_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "gsql", + "dataloaders", + "edge_nei_loader_sub.gsql", + ) + return install_query_files(self._graph, [sub_query_path, query_path], query_replace, force=force, distributed=[False, self.distributed_query]) def _start(self) -> None: # Create task and result queues - self._read_task_q = 
Queue(self.buffer_size * 2) + self._read_task_q = Queue(self.buffer_size) self._data_q = Queue(self.buffer_size) self._exit_event = Event() - self._start_request(True, "both") + self._start_request(True) # Start reading thread. if not self.is_hetero: @@ -3865,26 +3893,25 @@ def _start(self) -> None: e_attr_types[etype]["is_seed"] = "bool" v_attr_types = self._v_schema self._reader = Thread( - target=self._read_data, - args=( - self._exit_event, - self._read_task_q, - self._data_q, - "graph", - self.output_format, - self.v_in_feats, - self.v_out_labels, - self.v_extra_feats, - v_attr_types, - self.e_in_feats, - self.e_out_labels, - e_extra_feats, - e_attr_types, - self.add_self_loop, - self.delimiter, - True, - self.is_hetero, - self.callback_fn + target=self._read_graph_data, + kwargs=dict( + exit_event = self._exit_event, + in_q = self._read_task_q, + out_q = self._data_q, + batch_size = self.batch_size, + out_format = self.output_format, + v_in_feats = self.v_in_feats, + v_out_labels = self.v_out_labels, + v_extra_feats = self.v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = self.e_in_feats, + e_out_labels = self.e_out_labels, + e_extra_feats = e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = self.add_self_loop, + delimiter = self.delimiter, + is_hetero = self.is_hetero, + callback_fn = self.callback_fn ), ) self._reader.start() diff --git a/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader.gsql b/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader.gsql index 2948ad81..3a3fc6ab 100644 --- a/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader.gsql +++ b/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader.gsql @@ -1,6 +1,4 @@ CREATE QUERY edge_nei_loader_{QUERYSUFFIX}( - INT batch_size, - INT num_batches=1, INT num_neighbors=10, INT num_hops=2, BOOL shuffle=FALSE, @@ -9,8 +7,9 @@ CREATE QUERY edge_nei_loader_{QUERYSUFFIX}( SET e_types, SET seed_types, STRING delimiter, + INT num_chunks=2, STRING kafka_address="", - STRING kafka_topic, + STRING kafka_topic="", INT kafka_topic_partitions=1, STRING kafka_max_size="104857600", INT kafka_timeout=300000, @@ -49,227 +48,100 @@ CREATE QUERY edge_nei_loader_{QUERYSUFFIX}( sasl_password : SASL password for Kafka. ssl_ca_location: Path to CA certificate for verifying the Kafka broker key. 
*/ - TYPEDEF TUPLE ID_Tuple; - INT num_vertices; - INT kafka_errcode; SumAccum @tmp_id; - SumAccum @@kafka_error; - UINT producer; - MapAccum @@edges_sampled; - SetAccum @valid_v_out; - SetAccum @valid_v_in; - # Initialize Kafka producer - IF kafka_address != "" THEN - producer = init_kafka_producer( - kafka_address, kafka_max_size, security_protocol, - sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, - ssl_certificate_location, ssl_key_location, ssl_key_password, - ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, - sasl_kerberos_keytab, sasl_kerberos_principal); - END; - - # Shuffle vertex ID if needed start = {v_types}; + # Filter seeds if needed + seeds = SELECT s + FROM start:s -(seed_types:e)- v_types:t + WHERE filter_by is NULL OR e.getAttr(filter_by, "BOOL") + POST-ACCUM s.@tmp_id = getvid(s) + POST-ACCUM t.@tmp_id = getvid(t); + # Shuffle vertex ID if needed IF shuffle THEN - num_vertices = start.size(); + INT num_vertices = seeds.size(); res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = floor(rand()*num_vertices); - ELSE - res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = getvid(s); - END; - - SumAccum @@num_edges; - IF filter_by IS NOT NULL THEN - res = SELECT s - FROM start:s -(seed_types:e)- v_types:t WHERE e.getAttr(filter_by, "BOOL") - ACCUM - IF e.isDirected() THEN # we divide by two later to correct for undirected edges being counted twice, need to count directed edges twice to get correct count - @@num_edges += 2 - ELSE - @@num_edges += 1 - END; - ELSE - res = SELECT s - FROM start:s -(seed_types:e)- v_types:t - ACCUM - IF e.isDirected() THEN # we divide by two later to correct for undirected edges being counted twice, need to count directed edges twice to get correct count - @@num_edges += 2 - ELSE - @@num_edges += 1 - END; - END; - INT batch_s; - IF batch_size IS NULL THEN - batch_s = ceil((@@num_edges/2)/num_batches); - ELSE - batch_s = batch_size; + FROM seeds:s + POST-ACCUM s.@tmp_id = floor(rand()*num_vertices) + LIMIT 1; END; # Generate batches - FOREACH batch_id IN RANGE[0, num_batches-1] DO - SetAccum @@vertices; - SumAccum @@e_batch; - SumAccum @@v_batch; - SetAccum @@printed_edges; - SetAccum @@seeds; - SetAccum @@targets; - HeapAccum (1, tmp_id ASC) @@batch_heap; - @@batch_heap.resize(batch_s); + # If using kafka to export + IF kafka_address != "" THEN + SumAccum @@kafka_error; + + # Initialize Kafka producer + UINT producer = init_kafka_producer( + kafka_address, kafka_max_size, security_protocol, + sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, + ssl_certificate_location, ssl_key_location, ssl_key_password, + ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, + sasl_kerberos_keytab, sasl_kerberos_principal); - start = {v_types}; - IF filter_by IS NOT NULL THEN - res = - SELECT s - FROM start:s -(seed_types:e)- v_types:t - WHERE e.getAttr(filter_by, "BOOL") - AND - ((e.isDirected() AND ((t.@tmp_id >= s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id)))) - OR - (NOT e.isDirected() AND ((t.@tmp_id >= s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id))) - AND ((s.@tmp_id >= t.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id+s.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT 
@@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id))))) - - ACCUM - IF t.@tmp_id >= s.@tmp_id THEN - @@batch_heap += ID_Tuple(((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id), s, t) - ELSE - @@batch_heap += ID_Tuple(((s.@tmp_id*s.@tmp_id)+t.@tmp_id), s, t) - END; - - FOREACH elem IN @@batch_heap DO - SetAccum @@src; - @@seeds += elem.src; - @@targets += elem.tgt; - @@src += elem.src; - src = {@@src}; - res = SELECT s FROM src:s -(seed_types:e)- v_types:t - WHERE t == elem.tgt - ACCUM - s.@valid_v_out += elem.tgt, - t.@valid_v_in += elem.src; - END; - start = {@@seeds}; - res = - SELECT s - FROM start:s -(seed_types:e)- v_types:t - WHERE t in @@targets AND s IN t.@valid_v_in AND t IN s.@valid_v_out - ACCUM - {SEEDEDGEATTRS}, - @@printed_edges += e, - @@vertices += s, - @@vertices += t, - IF t.@tmp_id >= s.@tmp_id THEN - @@edges_sampled += (((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id) -> TRUE) - ELSE - @@edges_sampled += (((s.@tmp_id*s.@tmp_id)+t.@tmp_id) -> TRUE) - END; - ELSE - res = - SELECT s - FROM start:s -(seed_types:e)- v_types:t - WHERE ((e.isDirected() AND ((t.@tmp_id >= s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id)))) - OR - (NOT e.isDirected() AND ((t.@tmp_id >= s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id))) - AND ((s.@tmp_id >= t.@tmp_id AND NOT @@edges_sampled.containsKey((s.@tmp_id*s.@tmp_id)+t.@tmp_id+s.@tmp_id)) OR - (t.@tmp_id < s.@tmp_id AND NOT @@edges_sampled.containsKey((t.@tmp_id*t.@tmp_id)+s.@tmp_id))))) - - ACCUM - IF t.@tmp_id >= s.@tmp_id THEN - @@batch_heap += ID_Tuple(((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id), s, t) - ELSE - @@batch_heap += ID_Tuple(((s.@tmp_id*s.@tmp_id)+t.@tmp_id), s, t) - END; - - FOREACH elem IN @@batch_heap DO - SetAccum @@src; - @@seeds += elem.src; - @@targets += elem.tgt; - @@src += elem.src; - src = {@@src}; - res = SELECT s FROM src:s -(seed_types:e)- v_types:t - WHERE t == elem.tgt - ACCUM - s.@valid_v_out += elem.tgt, - t.@valid_v_in += elem.src; - END; - start = {@@seeds}; - res = - SELECT s - FROM start:s -(seed_types:e)- v_types:t - WHERE t in @@targets AND s IN t.@valid_v_in AND t IN s.@valid_v_out - ACCUM - {SEEDEDGEATTRS}, - @@printed_edges += e, - @@vertices += s, - @@vertices += t, - IF t.@tmp_id >= s.@tmp_id THEN - @@edges_sampled += (((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id) -> TRUE), - IF NOT e.isDirected() THEN - @@edges_sampled += (((s.@tmp_id*s.@tmp_id)+t.@tmp_id+s.@tmp_id) -> TRUE) - END - ELSE - @@edges_sampled += (((s.@tmp_id*s.@tmp_id)+t.@tmp_id) -> TRUE), - IF NOT e.isDirected() THEN - @@edges_sampled += (((t.@tmp_id*t.@tmp_id)+s.@tmp_id+t.@tmp_id) -> TRUE) - END - END; + FOREACH chunk IN RANGE[0, num_chunks-1] DO + res = SELECT s + FROM seeds:s -(seed_types:e)- v_types:t + WHERE (filter_by is NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk) + ACCUM + STRING e_type = e.type, + LIST msg = edge_nei_loader_sub_{QUERYSUFFIX}(s, t, delimiter, num_hops, num_neighbors, e_types, v_types, e_type), + BOOL is_first=True, + FOREACH i in msg DO + IF is_first THEN + INT kafka_errcode = write_to_kafka(producer, kafka_topic, (getvid(s)+getvid(t))%kafka_topic_partitions, "vertex_batch_" + stringify(getvid(s))+e.type+stringify(getvid(t)), i), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending 
vertex batch for " + stringify(getvid(s))+e.type+stringify(getvid(t)) + ": "+ stringify(kafka_errcode) + "\\n") + END, + is_first = False + ELSE + INT kafka_errcode = write_to_kafka(producer, kafka_topic, (getvid(s)+getvid(t))%kafka_topic_partitions, "edge_batch_" + stringify(getvid(s))+e.type+stringify(getvid(t)), i), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending edge batch for " + stringify(getvid(s))+e.type+stringify(getvid(t)) + ": "+ stringify(kafka_errcode) + "\\n") + END + END + END + LIMIT 1; END; - - # Get seed vertices - v_in_batch = @@vertices; - seeds = - SELECT s - FROM v_in_batch:s - POST-ACCUM - s.@valid_v_in.clear(), s.@valid_v_out.clear(), - {VERTEXATTRS}; - # Get neighbors of seeeds - FOREACH i IN RANGE[1, num_hops] DO - seeds = SELECT t - FROM seeds:s -(e_types:e)- v_types:t - SAMPLE num_neighbors EDGE WHEN s.outdegree() >= 1 - ACCUM - IF NOT @@printed_edges.contains(e) THEN - {OTHEREDGEATTRS}, - @@printed_edges += e - END; - attr = - SELECT s - FROM seeds:s - POST-ACCUM - IF NOT @@vertices.contains(s) THEN - {VERTEXATTRS}, - @@vertices += s - END; - END; - IF kafka_address != "" THEN - # Write to kafka - kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "vertex_batch_" + stringify(batch_id), @@v_batch); - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending vertex batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n"); - END; - kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "edge_batch_" + stringify(batch_id), @@e_batch); + + FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO + INT kafka_errcode = write_to_kafka(producer, kafka_topic, i, "STOP", ""); IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending edge batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n"); + @@kafka_error += ("Error sending STOP signal to topic partition " + stringify(i) + ": " + stringify(kafka_errcode) + "\n"); END; - ELSE - # Add to response - PRINT @@v_batch AS vertex_batch, @@e_batch AS edge_batch; END; - END; - IF kafka_address != "" THEN - kafka_errcode = close_kafka_producer(producer, kafka_timeout); + + INT kafka_errcode = close_kafka_producer(producer, kafka_timeout); IF kafka_errcode!=0 THEN @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); END; PRINT @@kafka_error as kafkaError; + # Else return as http response + ELSE + FOREACH chunk IN RANGE[0, num_chunks-1] DO + MapAccum @@v_batch; + MapAccum @@e_batch; + + res = SELECT s + FROM seeds:s -(seed_types:e)- v_types:t + WHERE (filter_by is NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk) + ACCUM + STRING e_type = e.type, + LIST msg = edge_nei_loader_sub_{QUERYSUFFIX}(s, t, delimiter, num_hops, num_neighbors, e_types, v_types, e_type), + BOOL is_first=True, + FOREACH i in msg DO + IF is_first THEN + @@v_batch += (stringify(getvid(s))+e.type+stringify(getvid(t)) -> i), + is_first = False + ELSE + @@e_batch += (stringify(getvid(s))+e.type+stringify(getvid(t)) -> i) + END + END + LIMIT 1; + + FOREACH (k,v) IN @@v_batch DO + PRINT v as vertex_batch, @@e_batch.get(k) as edge_batch; + END; + END; END; } \ No newline at end of file diff --git a/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader_sub.gsql b/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader_sub.gsql new file mode 100644 index 00000000..23f06ff5 --- /dev/null +++ b/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader_sub.gsql @@ -0,0 +1,46 @@ +CREATE QUERY 
edge_nei_loader_sub_{QUERYSUFFIX} (VERTEX u, VERTEX v, STRING delimiter, INT num_hops, INT num_neighbors, SET e_types, SET v_types, STRING seed_type) +RETURNS (ListAccum) +SYNTAX V1 +{ + SumAccum @@v_batch; + SumAccum @@e_batch; + SetAccum @@printed_vertices; + SetAccum @@printed_edges; + ListAccum @@ret; + + source = {u}; + res = SELECT s + FROM source:s -(seed_type:e)- v_types:t + WHERE t==v + ACCUM + @@printed_edges += e, + {SEEDEDGEATTRS}; + + start = {u,v}; + res = SELECT s + FROM start:s + POST-ACCUM + @@printed_vertices += s, + {VERTEXATTRS}; + + FOREACH i IN RANGE[1, num_hops] DO + start = SELECT t + FROM start:s -(e_types:e)- v_types:t + SAMPLE num_neighbors EDGE WHEN s.outdegree() >= 1 + ACCUM + IF NOT @@printed_edges.contains(e) THEN + @@printed_edges += e, + {OTHEREDGEATTRS} + END; + start = SELECT s + FROM start:s + POST-ACCUM + IF NOT @@printed_vertices.contains(s) THEN + @@printed_vertices += s, + {VERTEXATTRS} + END; + END; + @@ret += @@v_batch; + @@ret += @@e_batch; + RETURN @@ret; +} diff --git a/tests/test_gds_EdgeNeighborLoader.py b/tests/test_gds_EdgeNeighborLoader.py index a284689a..7c503c54 100644 --- a/tests/test_gds_EdgeNeighborLoader.py +++ b/tests/test_gds_EdgeNeighborLoader.py @@ -2,6 +2,7 @@ from pyTigerGraphUnitTest import make_connection from torch_geometric.data import Data as pygData +from torch_geometric.data import HeteroData as pygHeteroData from pyTigerGraph.gds.dataloaders import EdgeNeighborLoader from pyTigerGraph.gds.utilities import is_query_installed @@ -12,7 +13,7 @@ class TestGDSEdgeNeighborLoaderKafka(unittest.TestCase): def setUpClass(cls): cls.conn = make_connection(graphname="Cora") - def test_iterate_pyg(self): + def test_init(self): loader = EdgeNeighborLoader( graph=self.conn, v_in_feats=["x"], @@ -23,12 +24,26 @@ def test_iterate_pyg(self): shuffle=False, filter_by=None, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, + kafka_address="kafka:9092", + ) + self.assertTrue(is_query_installed(self.conn, loader.query_name)) + self.assertIsNone(loader.num_batches) + + def test_iterate_pyg(self): + loader = EdgeNeighborLoader( + graph=self.conn, + v_in_feats=["x"], + e_extra_feats=["is_train"], + batch_size=1024, + num_neighbors=10, + num_hops=2, + shuffle=True, + filter_by=None, + output_format="PyG", kafka_address="kafka:9092", ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygData) @@ -38,7 +53,11 @@ def test_iterate_pyg(self): self.assertGreater(data["x"].shape[0], 0) self.assertGreater(data["edge_index"].shape[1], 0) num_batches += 1 + batch_sizes.append(int(data["is_seed"].sum())) self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) def test_sasl_ssl(self): loader = EdgeNeighborLoader( @@ -92,12 +111,9 @@ def test_init(self): shuffle=False, filter_by=None, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 11) + self.assertIsNone(loader.num_batches) def test_iterate_pyg(self): loader = EdgeNeighborLoader( @@ -107,14 +123,12 @@ def test_iterate_pyg(self): batch_size=1024, num_neighbors=10, num_hops=2, - shuffle=False, + shuffle=True, filter_by=None, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) 
self.assertIsInstance(data, pygData) @@ -124,7 +138,11 @@ def test_iterate_pyg(self): self.assertGreater(data["x"].shape[0], 0) self.assertGreater(data["edge_index"].shape[1], 0) num_batches += 1 + batch_sizes.append(int(data["is_seed"].sum())) self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) def test_iterate_spektral(self): loader = EdgeNeighborLoader( @@ -154,13 +172,153 @@ def test_iterate_spektral(self): self.assertEqual(num_batches, 11) +class TestGDSHeteroEdgeNeighborLoaderREST(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.conn = make_connection(graphname="hetero") + + def test_init(self): + loader = EdgeNeighborLoader( + graph=self.conn, + v_in_feats={"v0": ["x", "y"], "v2": ["x"]}, + e_extra_feats={"v2v0":["is_train"], "v0v0":[], "v2v2":[]}, + e_seed_types=["v2v0"], + batch_size=100, + num_neighbors=5, + num_hops=2, + shuffle=True, + filter_by=None, + output_format="PyG", + ) + self.assertTrue(is_query_installed(self.conn, loader.query_name)) + self.assertIsNone(loader.num_batches) + + def test_iterate_pyg(self): + loader = EdgeNeighborLoader( + graph=self.conn, + v_in_feats={"v0": ["x", "y"], "v2": ["x"]}, + e_extra_feats={"v2v0":["is_train"], "v0v0":[], "v2v2":[]}, + e_seed_types=["v2v0"], + batch_size=100, + num_neighbors=5, + num_hops=2, + shuffle=True, + filter_by=None, + output_format="PyG", + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + self.assertIsInstance(data, pygHeteroData) + self.assertGreater(data["v0"]["x"].shape[0], 0) + self.assertGreater(data["v2"]["x"].shape[0], 0) + self.assertTrue( + data['v2', 'v2v0', 'v0']["edge_index"].shape[1] > 0 + and data['v2', 'v2v0', 'v0']["edge_index"].shape[1] <= 943 + ) + self.assertEqual( + data['v2', 'v2v0', 'v0']["edge_index"].shape[1], + data['v2', 'v2v0', 'v0']["is_train"].shape[0] + ) + if ('v0', 'v0v0', 'v0') in data.edge_types: + self.assertTrue( + data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0 + and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710 + ) + if ('v2', 'v2v2', 'v2') in data.edge_types: + self.assertTrue( + data['v2', 'v2v2', 'v2']["edge_index"].shape[1] > 0 + and data['v2', 'v2v2', 'v2']["edge_index"].shape[1] <= 966 + ) + num_batches += 1 + batch_sizes.append(int(data['v2', 'v2v0', 'v0']["is_seed"].sum())) + self.assertEqual(num_batches, 10) + for i in batch_sizes[:-1]: + self.assertEqual(i, 100) + self.assertLessEqual(batch_sizes[-1], 100) + + +class TestGDSHeteroEdgeNeighborLoaderKafka(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.conn = make_connection(graphname="hetero") + + def test_init(self): + loader = EdgeNeighborLoader( + graph=self.conn, + v_in_feats={"v0": ["x", "y"], "v2": ["x"]}, + e_extra_feats={"v2v0":["is_train"], "v0v0":[], "v2v2":[]}, + e_seed_types=["v2v0"], + batch_size=100, + num_neighbors=5, + num_hops=2, + shuffle=True, + filter_by=None, + output_format="PyG", + kafka_address="kafka:9092" + ) + self.assertTrue(is_query_installed(self.conn, loader.query_name)) + self.assertIsNone(loader.num_batches) + + def test_iterate_pyg(self): + loader = EdgeNeighborLoader( + graph=self.conn, + v_in_feats={"v0": ["x", "y"], "v2": ["x"]}, + e_extra_feats={"v2v0":["is_train"], "v0v0":[], "v2v2":[]}, + e_seed_types=["v2v0"], + batch_size=100, + num_neighbors=5, + num_hops=2, + shuffle=True, + filter_by=None, + output_format="PyG", + kafka_address="kafka:9092" + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # 
print(num_batches, data) + self.assertIsInstance(data, pygHeteroData) + self.assertGreater(data["v0"]["x"].shape[0], 0) + self.assertGreater(data["v2"]["x"].shape[0], 0) + self.assertTrue( + data['v2', 'v2v0', 'v0']["edge_index"].shape[1] > 0 + and data['v2', 'v2v0', 'v0']["edge_index"].shape[1] <= 943 + ) + self.assertEqual( + data['v2', 'v2v0', 'v0']["edge_index"].shape[1], + data['v2', 'v2v0', 'v0']["is_train"].shape[0] + ) + if ('v0', 'v0v0', 'v0') in data.edge_types: + self.assertTrue( + data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0 + and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710 + ) + if ('v2', 'v2v2', 'v2') in data.edge_types: + self.assertTrue( + data['v2', 'v2v2', 'v2']["edge_index"].shape[1] > 0 + and data['v2', 'v2v2', 'v2']["edge_index"].shape[1] <= 966 + ) + num_batches += 1 + batch_sizes.append(int(data['v2', 'v2v0', 'v0']["is_seed"].sum())) + self.assertEqual(num_batches, 10) + for i in batch_sizes[:-1]: + self.assertEqual(i, 100) + self.assertLessEqual(batch_sizes[-1], 100) + + if __name__ == "__main__": suite = unittest.TestSuite() + suite.addTest(TestGDSEdgeNeighborLoaderKafka("test_init")) suite.addTest(TestGDSEdgeNeighborLoaderKafka("test_iterate_pyg")) # suite.addTest(TestGDSEdgeNeighborLoaderKafka("test_sasl_ssl")) suite.addTest(TestGDSEdgeNeighborLoaderREST("test_init")) suite.addTest(TestGDSEdgeNeighborLoaderREST("test_iterate_pyg")) # suite.addTest(TestGDSEdgeNeighborLoaderREST("test_iterate_spektral")) - + suite.addTest(TestGDSHeteroEdgeNeighborLoaderREST("test_init")) + suite.addTest(TestGDSHeteroEdgeNeighborLoaderREST("test_iterate_pyg")) + suite.addTest(TestGDSHeteroEdgeNeighborLoaderKafka("test_init")) + suite.addTest(TestGDSHeteroEdgeNeighborLoaderKafka("test_iterate_pyg")) runner = unittest.TextTestRunner(verbosity=2, failfast=True) runner.run(suite) From c3434b54671d9f3d00c77f0234f44da748d14d47 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Mon, 6 Nov 2023 13:11:11 -0800 Subject: [PATCH 22/36] feat: add new hgt loader --- pyTigerGraph/gds/dataloaders.py | 272 +++++++++++++----- .../gds/gsql/dataloaders/hgt_loader.gsql | 244 ++++++++-------- .../gds/gsql/dataloaders/hgt_loader_sub.gsql | 32 +++ tests/test_gds_HGTLoader.py | 268 ++++++++++++++--- tests/test_gds_NeighborLoader.py | 2 +- 5 files changed, 582 insertions(+), 236 deletions(-) create mode 100644 pyTigerGraph/gds/gsql/dataloaders/hgt_loader_sub.gsql diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index 3d898829..9dd0735f 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -4636,9 +4636,15 @@ def __init__( self._seed_types = list(filter_by.keys()) else: self._seed_types = self._vtypes + if set(self._seed_types) - set(self._vtypes): + raise ValueError("Seed type has to be one of the vertex types to retrieve") if batch_size: - # If batch_size is given, calculate the number of batches + # batch size takes precedence over number of batches + self.batch_size = batch_size + self.num_batches = None + else: + # If number of batches is given, calculate batch size if not filter_by: num_vertices = sum(self._graph.getVertexCount(self._seed_types).values()) elif isinstance(filter_by, str): @@ -4653,12 +4659,9 @@ def __init__( ) else: raise ValueError("filter_by should be None, attribute name, or dict of {type name: attribute name}.") - self.num_batches = math.ceil(num_vertices / batch_size) - else: - # Otherwise, take the number of batches as is. 
+ self.batch_size = math.ceil(num_vertices / num_batches) self.num_batches = num_batches # Initialize parameters for the query - self._payload["num_batches"] = self.num_batches self._payload["num_hops"] = num_hops if filter_by: if isinstance(filter_by, str): @@ -4668,15 +4671,12 @@ def __init__( if len(attr) != 1: raise NotImplementedError("Filtering by different attributes for different vertex types is not supported. Please use the same attribute for different types.") self._payload["filter_by"] = attr.pop() - if batch_size: - self._payload["batch_size"] = batch_size self._payload["shuffle"] = shuffle self._payload["v_types"] = self._vtypes self._payload["e_types"] = self._etypes self._payload["seed_types"] = self._seed_types self._payload["delimiter"] = self.delimiter self._payload["input_vertices"] = [] - self._payload["num_heap_inserts"] = self.num_heap_inserts # Output self.add_self_loop = add_self_loop # Install query @@ -4701,6 +4701,23 @@ def _install_query(self, force: bool = False): if isinstance(self.v_in_feats, dict) or isinstance(self.e_in_feats, dict): # Multiple vertex types print_query_seed = "" + for idx, vtype in enumerate(self._seed_types): + v_attr_names = ( + self.v_in_feats.get(vtype, []) + + self.v_out_labels.get(vtype, []) + + self.v_extra_feats.get(vtype, []) + ) + v_attr_types = self._v_schema[vtype] + print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) + print_query_seed += """ + {} s.type == "{}" THEN + @@v_batch += (s.type + delimiter + stringify(getvid(s)) {} + delimiter + "1\\n")"""\ + .format("IF" if idx==0 else "ELSE IF", + vtype, + "+ delimiter + " + print_attr if v_attr_names else "") + print_query_seed += """ + END""" + query_replace["{SEEDVERTEXATTRS}"] = print_query_seed print_query_other = "" for idx, vtype in enumerate(self._vtypes): v_attr_names = ( @@ -4709,20 +4726,15 @@ def _install_query(self, force: bool = False): + self.v_extra_feats.get(vtype, []) ) v_attr_types = self._v_schema[vtype] - if v_attr_names: - print_attr = print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query_seed += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + delimiter + "1\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype, print_attr) - print_query_other += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + {} + delimiter + "0\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype, print_attr) - else: - print_query_seed += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + "1\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype) - print_query_other += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + "0\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype) - print_query_seed += "END" - print_query_other += "END" - query_replace["{SEEDVERTEXATTRS}"] = print_query_seed + print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) + print_query_other += """ + {} s.type == "{}" THEN + @@v_batch += (s.type + delimiter + stringify(getvid(s)) {} + delimiter + "0\\n")"""\ + .format("IF" if idx==0 else "ELSE IF", + vtype, + "+ delimiter + " + print_attr if v_attr_names else "") + print_query_other += """ + END""" query_replace["{OTHERVERTEXATTRS}"] = print_query_other # Generate select for each type of neighbors print_select = "" @@ -4741,28 +4753,30 @@ def _install_query(self, force: bool 
= False): e_attr_types = self._e_schema[etype] if vtype!=e_attr_types["FromVertexTypeName"] and vtype!=e_attr_types["ToVertexTypeName"]: continue - if e_attr_names: - print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) - print_query += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + delimiter + {} + "\\n")\n'.format( - "IF" if eidx==0 else "ELSE IF", etype, print_attr) - else: - print_query += '{} e.type == "{}" THEN \n @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) + "\\n")\n'.format( - "IF" if eidx==0 else "ELSE IF", etype) + print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) + print_query += """ + {} e.type == "{}" THEN + @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) {} + "\\n")"""\ + .format("IF" if eidx==0 else "ELSE IF", + etype, + "+ delimiter + " + print_attr if e_attr_names else "") eidx += 1 if print_query: - print_query += "END" - print_select += """seed{} = SELECT t - FROM seeds:s -(e_types:e)- {}:t - SAMPLE {} EDGE WHEN s.outdegree() >= 1 - ACCUM - IF NOT @@printed_edges.contains(e) THEN - @@printed_edges += e, - {} - END; - """.format(vidx, vtype, self.num_neighbors[vtype], print_query) + print_query += """ + END""" + print_select += """ + seed{} = SELECT t + FROM seeds:s -(e_types:e)- {}:t + SAMPLE {} EDGE WHEN s.outdegree() >= 1 + ACCUM + IF NOT @@printed_edges.contains(e) THEN + @@printed_edges += e, + {} + END;""".format(vidx, vtype, self.num_neighbors[vtype], print_query) seeds.append("seed{}".format(vidx)) vidx += 1 - print_select += "seeds = {};".format(" UNION ".join(seeds)) + print_select += """ + seeds = {};""".format(" UNION ".join(seeds)) query_replace["{SELECTNEIGHBORS}"] = print_select # Install query query_path = os.path.join( @@ -4771,15 +4785,21 @@ def _install_query(self, force: bool = False): "dataloaders", "hgt_loader.gsql", ) - return install_query_file(self._graph, query_path, query_replace, force=force, distributed=self.distributed_query) + sub_query_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "gsql", + "dataloaders", + "hgt_loader_sub.gsql", + ) + return install_query_files(self._graph, [sub_query_path, query_path], query_replace, force=force, distributed=[False, self.distributed_query]) def _start(self) -> None: # Create task and result queues - self._read_task_q = Queue(self.buffer_size * 2) + self._read_task_q = Queue(self.buffer_size) self._data_q = Queue(self.buffer_size) self._exit_event = Event() - self._start_request(True, "both") + self._start_request(True) # Start reading thread. 
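+        # The reader thread now receives one (vertex_batch, edge_batch) string
+        # pair per seed vertex and assembles batch_size of those pairs into a
+        # single output batch, deduplicating vertex and edge lines before
+        # parsing; the total number of batches is therefore unknown up front.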
if not self.is_hetero: @@ -4796,26 +4816,25 @@ def _start(self) -> None: v_attr_types[vtype]["is_seed"] = "bool" e_attr_types = self._e_schema self._reader = Thread( - target=self._read_data, - args=( - self._exit_event, - self._read_task_q, - self._data_q, - "graph", - self.output_format, - self.v_in_feats, - self.v_out_labels, - v_extra_feats, - v_attr_types, - self.e_in_feats, - self.e_out_labels, - self.e_extra_feats, - e_attr_types, - self.add_self_loop, - self.delimiter, - True, - self.is_hetero, - self.callback_fn + target=self._read_graph_data, + kwargs=dict( + exit_event = self._exit_event, + in_q = self._read_task_q, + out_q = self._data_q, + batch_size = self.batch_size, + out_format = self.output_format, + v_in_feats = self.v_in_feats, + v_out_labels = self.v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = self.e_in_feats, + e_out_labels = self.e_out_labels, + e_extra_feats = self.e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = self.add_self_loop, + delimiter = self.delimiter, + is_hetero = self.is_hetero, + callback_fn = self.callback_fn ), ) self._reader.start() @@ -4852,7 +4871,6 @@ def fetch(self, vertices: list) -> None: _payload = {} _payload["v_types"] = self._payload["v_types"] _payload["e_types"] = self._payload["e_types"] - _payload["num_batches"] = 1 _payload["num_hops"] = self._payload["num_hops"] _payload["delimiter"] = self._payload["delimiter"] _payload["input_vertices"] = [] @@ -4877,11 +4895,15 @@ def fetch(self, vertices: list) -> None: v_attr_types[vtype]["is_seed"] = "bool" v_attr_types[vtype]["primary_id"] = "str" e_attr_types = self._e_schema - i = resp[0] - data = self._parse_data( - raw = (i["vertex_batch"], i["edge_batch"]), - in_format = "graph", - out_format = self.output_format, + vertex_batch = set() + edge_batch = set() + for i in resp: + if "pids" in i: + break + vertex_batch.update(i["vertex_batch"].splitlines()) + edge_batch.update(i["edge_batch"].splitlines()) + data = self._parse_graph_data_to_df( + raw = (vertex_batch, edge_batch), v_in_feats = self.v_in_feats, v_out_labels = self.v_out_labels, v_extra_feats = v_extra_feats, @@ -4890,12 +4912,118 @@ def fetch(self, vertices: list) -> None: e_out_labels = self.e_out_labels, e_extra_feats = self.e_extra_feats, e_attr_types = e_attr_types, - add_self_loop = self.add_self_loop, delimiter = self.delimiter, - reindex = True, primary_id = i["pids"], is_hetero = self.is_hetero, - callback_fn = self.callback_fn ) + if self.output_format == "dataframe" or self.output_format== "df": + vertices, edges = data + if not self.is_hetero: + for column in vertices.columns: + vertices[column] = pd.to_numeric(vertices[column], errors="ignore") + for column in edges.columns: + edges[column] = pd.to_numeric(edges[column], errors="ignore") + else: + for key in vertices: + for column in vertices[key].columns: + vertices[key][column] = pd.to_numeric(vertices[key][column], errors="ignore") + for key in edges: + for column in edges[key].columns: + edges[key][column] = pd.to_numeric(edges[key][column], errors="ignore") + data = (vertices, edges) + elif self.output_format == "pyg": + try: + import torch + except ImportError: + raise ImportError( + "PyTorch is not installed. Please install it to use PyG or DGL output." + ) + try: + import torch_geometric as pyg + except ImportError: + raise ImportError( + "PyG is not installed. Please install PyG to use PyG format." 
+ ) + data = BaseLoader._parse_df_to_pyg( + raw = data, + v_in_feats = self.v_in_feats, + v_out_labels = self.v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = self.e_in_feats, + e_out_labels = self.e_out_labels, + e_extra_feats = self.e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = self.add_self_loop, + is_hetero = self.is_hetero, + torch = torch, + pyg = pyg + ) + elif self.output_format == "dgl": + try: + import torch + except ImportError: + raise ImportError( + "PyTorch is not installed. Please install it to use PyG or DGL output." + ) + try: + import dgl + except ImportError: + raise ImportError( + "DGL is not installed. Please install DGL to use DGL format." + ) + data = BaseLoader._parse_df_to_dgl( + raw = data, + v_in_feats = self.v_in_feats, + v_out_labels = self.v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = self.e_in_feats, + e_out_labels = self.e_out_labels, + e_extra_feats = self.e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = self.add_self_loop, + is_hetero = self.is_hetero, + torch = torch, + dgl= dgl + ) + elif self.output_format == "spektral" and self.is_hetero==False: + try: + import tensorflow as tf + except ImportError: + raise ImportError( + "Tensorflow is not installed. Please install it to use spektral output." + ) + try: + import scipy + except ImportError: + raise ImportError( + "scipy is not installed. Please install it to use spektral output." + ) + try: + import spektral + except ImportError: + raise ImportError( + "Spektral is not installed. Please install it to use spektral output." + ) + data = BaseLoader._parse_df_to_spektral( + raw = data, + v_in_feats = self.v_in_feats, + v_out_labels = self.v_out_labels, + v_extra_feats = v_extra_feats, + v_attr_types = v_attr_types, + e_in_feats = self.e_in_feats, + e_out_labels = self.e_out_labels, + e_extra_feats = self.e_extra_feats, + e_attr_types = e_attr_types, + add_self_loop = self.add_self_loop, + is_hetero = self.is_hetero, + scipy = scipy, + spektral = spektral + ) + else: + raise NotImplementedError + if self.callback_fn: + data = self.callback_fn(data) # Return data return data diff --git a/pyTigerGraph/gds/gsql/dataloaders/hgt_loader.gsql b/pyTigerGraph/gds/gsql/dataloaders/hgt_loader.gsql index 4e43e28f..401a2a6b 100644 --- a/pyTigerGraph/gds/gsql/dataloaders/hgt_loader.gsql +++ b/pyTigerGraph/gds/gsql/dataloaders/hgt_loader.gsql @@ -1,7 +1,5 @@ CREATE QUERY hgt_loader_{QUERYSUFFIX}( SET input_vertices, - INT batch_size, - INT num_batches=1, INT num_hops=2, BOOL shuffle=FALSE, STRING filter_by, @@ -9,8 +7,9 @@ CREATE QUERY hgt_loader_{QUERYSUFFIX}( SET e_types, SET seed_types, STRING delimiter, + INT num_chunks=2, STRING kafka_address="", - STRING kafka_topic, + STRING kafka_topic="", INT kafka_topic_partitions=1, STRING kafka_max_size="104857600", INT kafka_timeout=300000, @@ -25,8 +24,7 @@ CREATE QUERY hgt_loader_{QUERYSUFFIX}( STRING ssl_endpoint_identification_algorithm="", STRING sasl_kerberos_service_name="", STRING sasl_kerberos_keytab="", - STRING sasl_kerberos_principal="", - INT num_heap_inserts = 10 + STRING sasl_kerberos_principal="" ) SYNTAX V1 { /* This query generates the neighborhood subgraphs of given seed vertices (i.e., `input_vertices`). @@ -51,142 +49,138 @@ CREATE QUERY hgt_loader_{QUERYSUFFIX}( sasl_password : SASL password for Kafka. ssl_ca_location: Path to CA certificate for verifying the Kafka broker key. 
*/ - TYPEDEF TUPLE ID_Tuple; - INT num_vertices; - INT kafka_errcode; SumAccum @tmp_id; - SumAccum @@kafka_error; - UINT producer; - INT batch_s; - OrAccum @prev_sampled; - # Initialize Kafka producer - IF kafka_address != "" THEN - producer = init_kafka_producer( - kafka_address, kafka_max_size, security_protocol, - sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, - ssl_certificate_location, ssl_key_location, ssl_key_password, - ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, - sasl_kerberos_keytab, sasl_kerberos_principal); - END; - - # Shuffle vertex ID if needed + # If getting all vertices of given types IF input_vertices.size()==0 THEN start = {seed_types}; - IF filter_by IS NOT NULL THEN - start = SELECT s FROM start:s WHERE s.getAttr(filter_by, "BOOL"); - END; + # Filter seeds if needed + seeds = SELECT s + FROM start:s + WHERE filter_by is NULL OR s.getAttr(filter_by, "BOOL"); + # Shuffle vertex ID if needed IF shuffle THEN - num_vertices = start.size(); + INT num_vertices = seeds.size(); res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = floor(rand()*num_vertices); + FROM seeds:s + POST-ACCUM s.@tmp_id = floor(rand()*num_vertices) + LIMIT 1; ELSE res = SELECT s - FROM start:s - POST-ACCUM s.@tmp_id = getvid(s); + FROM seeds:s + POST-ACCUM s.@tmp_id = getvid(s) + LIMIT 1; END; - END; - IF batch_size IS NULL THEN - batch_s = ceil(res.size()/num_batches); - ELSE - batch_s = batch_size; - END; - # Generate subgraphs - FOREACH batch_id IN RANGE[0, num_batches-1] DO - SumAccum @@v_batch; - SumAccum @@e_batch; - SetAccum @@printed_vertices; - SetAccum @@printed_edges; - SetAccum @@seeds; - # Get seeds - IF input_vertices.size()==0 THEN - start = {seed_types}; - HeapAccum (1, tmp_id ASC) @@batch_heap; - @@batch_heap.resize(batch_s); - IF filter_by IS NOT NULL THEN - FOREACH iter IN RANGE[0,num_heap_inserts-1] DO - _verts = SELECT s FROM start:s - WHERE s.@tmp_id % num_heap_inserts == iter AND NOT s.@prev_sampled AND s.getAttr(filter_by, "BOOL") - POST-ACCUM @@batch_heap += ID_Tuple(s.@tmp_id, s); - END; - FOREACH elem IN @@batch_heap DO - @@seeds += elem.v; - END; - seeds = {@@seeds}; - seeds = SELECT s - FROM seeds:s - POST-ACCUM - s.@prev_sampled += TRUE, - {SEEDVERTEXATTRS}, - @@printed_vertices += s; - ELSE - FOREACH iter IN RANGE[0,num_heap_inserts-1] DO - _verts = SELECT s FROM start:s - WHERE s.@tmp_id % num_heap_inserts == iter AND NOT s.@prev_sampled - POST-ACCUM @@batch_heap += ID_Tuple(s.@tmp_id, s); - END; - FOREACH elem IN @@batch_heap DO - @@seeds += elem.v; - END; - seeds = {@@seeds}; - seeds = SELECT s - FROM start:s - POST-ACCUM - s.@prev_sampled += TRUE, - {SEEDVERTEXATTRS}, - @@printed_vertices += s; - END; - ELSE - start = input_vertices; - seeds = SELECT s - FROM start:s - POST-ACCUM - @@printed_vertices += s, - {SEEDVERTEXATTRS}; - END; - # Get neighbors of seeeds - FOREACH i IN RANGE[1, num_hops] DO - {SELECTNEIGHBORS} - attr = SELECT s - FROM seeds:s - POST-ACCUM - IF NOT @@printed_vertices.contains(s) THEN - @@printed_vertices += s, - {OTHERVERTEXATTRS} - END; - END; + # Export data + # If using kafka to export IF kafka_address != "" THEN - # Write to kafka - kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "vertex_batch_" + stringify(batch_id), @@v_batch); - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending vertex batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n"); + SumAccum @@kafka_error; + + # Initialize Kafka producer + UINT producer = init_kafka_producer( 
+ kafka_address, kafka_max_size, security_protocol, + sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, + ssl_certificate_location, ssl_key_location, ssl_key_password, + ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, + sasl_kerberos_keytab, sasl_kerberos_principal); + + FOREACH chunk IN RANGE[0, num_chunks-1] DO + res = SELECT s + FROM seeds:s + WHERE s.@tmp_id % num_chunks == chunk + POST-ACCUM + LIST msg = hgt_loader_sub_{QUERYSUFFIX}(s, delimiter, num_hops, e_types, v_types), + BOOL is_first=True, + FOREACH i in msg DO + IF is_first THEN + INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "vertex_batch_" + stringify(getvid(s)), i), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending vertex batch for " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\\n") + END, + is_first = False + ELSE + INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "edge_batch_" + stringify(getvid(s)), i), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending edge batch for " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\\n") + END + END + END + LIMIT 1; + END; + + FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO + INT kafka_errcode = write_to_kafka(producer, kafka_topic, i, "STOP", ""); + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending STOP signal to topic partition " + stringify(i) + ": " + stringify(kafka_errcode) + "\n"); + END; END; - kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "edge_batch_" + stringify(batch_id), @@e_batch); + + INT kafka_errcode = close_kafka_producer(producer, kafka_timeout); IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending edge batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n"); + @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); END; + PRINT @@kafka_error as kafkaError; + # Else return as http response ELSE - # Add to response - IF input_vertices.size()==0 THEN - PRINT @@v_batch AS vertex_batch, @@e_batch AS edge_batch; - ELSE - MapAccum @@id_map; - source = @@printed_vertices; - res = - SELECT s - FROM source:s - POST-ACCUM @@id_map += (getvid(s) -> s); - PRINT @@v_batch AS vertex_batch, @@e_batch AS edge_batch, @@id_map AS pids; + FOREACH chunk IN RANGE[0, num_chunks-1] DO + MapAccum @@v_batch; + MapAccum @@e_batch; + + res = SELECT s + FROM seeds:s + WHERE s.@tmp_id % num_chunks == chunk + POST-ACCUM + LIST msg = hgt_loader_sub_{QUERYSUFFIX}(s, delimiter, num_hops, e_types, v_types), + BOOL is_first=True, + FOREACH i in msg DO + IF is_first THEN + @@v_batch += (getvid(s) -> i), + is_first = False + ELSE + @@e_batch += (getvid(s) -> i) + END + END + LIMIT 1; + + FOREACH (k,v) IN @@v_batch DO + PRINT v as vertex_batch, @@e_batch.get(k) as edge_batch; + END; END; - END; - END; - IF kafka_address != "" THEN - kafka_errcode = close_kafka_producer(producer, kafka_timeout); - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); END; - PRINT @@kafka_error as kafkaError; + # Else get given vertices. 
+ ELSE + MapAccum @@v_batch; + MapAccum @@e_batch; + MapAccum @@id_map; + + seeds = input_vertices; + res = SELECT s + FROM seeds:s + POST-ACCUM + LIST msg = hgt_loader_sub_{QUERYSUFFIX}(s, delimiter, num_hops, e_types, v_types), + BOOL is_first=True, + FOREACH i in msg DO + IF is_first THEN + @@v_batch += (getvid(s) -> i), + is_first = False + ELSE + @@e_batch += (getvid(s) -> i) + END + END, + @@id_map += (getvid(s) -> s) + LIMIT 1; + + FOREACH (k,v) IN @@v_batch DO + PRINT v as vertex_batch, @@e_batch.get(k) as edge_batch; + END; + + FOREACH hop IN RANGE[1, num_hops] DO + seeds = SELECT t + FROM seeds:s -(e_types:e)- v_types:t + POST-ACCUM + @@id_map += (getvid(t) -> t); + END; + PRINT @@id_map AS pids; END; } \ No newline at end of file diff --git a/pyTigerGraph/gds/gsql/dataloaders/hgt_loader_sub.gsql b/pyTigerGraph/gds/gsql/dataloaders/hgt_loader_sub.gsql new file mode 100644 index 00000000..88fa5558 --- /dev/null +++ b/pyTigerGraph/gds/gsql/dataloaders/hgt_loader_sub.gsql @@ -0,0 +1,32 @@ +CREATE QUERY hgt_loader_sub_{QUERYSUFFIX} (VERTEX v, STRING delimiter, INT num_hops, SET e_types, SET v_types) +RETURNS (ListAccum) +SYNTAX V1 +{ + SumAccum @@v_batch; + SumAccum @@e_batch; + SetAccum @@printed_vertices; + SetAccum @@printed_edges; + ListAccum @@ret; + + seeds = {v}; + res = SELECT s + FROM seeds:s + POST-ACCUM + @@printed_vertices += s, + {SEEDVERTEXATTRS}; + + FOREACH i IN RANGE[1, num_hops] DO + {SELECTNEIGHBORS} + + seeds = SELECT s + FROM seeds:s + POST-ACCUM + IF NOT @@printed_vertices.contains(s) THEN + @@printed_vertices += s, + {OTHERVERTEXATTRS} + END; + END; + @@ret += @@v_batch; + @@ret += @@e_batch; + RETURN @@ret; +} diff --git a/tests/test_gds_HGTLoader.py b/tests/test_gds_HGTLoader.py index 15c51820..cfa5f8f3 100644 --- a/tests/test_gds_HGTLoader.py +++ b/tests/test_gds_HGTLoader.py @@ -12,8 +12,6 @@ class TestGDSHGTLoaderREST(unittest.TestCase): @classmethod def setUpClass(cls): cls.conn = make_connection(graphname="hetero") - splitter = cls.conn.gds.vertexSplitter(v_types=["v2"], train_mask=0.3) - splitter.run() def test_init(self): loader = HGTLoader( @@ -26,12 +24,9 @@ def test_init(self): num_hops=2, shuffle=True, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 18) + self.assertIsNone(loader.num_batches) def test_whole_graph_df(self): loader = HGTLoader( @@ -44,14 +39,11 @@ def test_whole_graph_df(self): num_hops=2, shuffle=False, output_format="dataframe", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) data = loader.data - self.assertTupleEqual(data[0]["v0"].shape, (76, 7)) - self.assertTupleEqual(data[0]["v1"].shape, (110, 3)) - self.assertTupleEqual(data[0]["v2"].shape, (100, 3)) + self.assertTupleEqual(data[0]["v0"].shape, (152, 7)) + self.assertTupleEqual(data[0]["v1"].shape, (220, 3)) + self.assertTupleEqual(data[0]["v2"].shape, (200, 3)) self.assertTrue( data[1]["v0v0"].shape[0] > 0 and data[1]["v0v0"].shape[0] <= 710 ) @@ -82,9 +74,6 @@ def test_whole_graph_pyg(self): num_hops=2, shuffle=False, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) data = loader.data self.assertTupleEqual(data["v0"]["x"].shape, (76, 77)) @@ -129,23 +118,59 @@ def test_iterate_pyg(self): v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]}, v_out_labels={"v0": ["y"]}, v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, - num_batches=6, + v_seed_types=["v2"], + batch_size=16, num_hops=2, 
shuffle=False, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, - filter_by= {"v2": "train_mask"} ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) self.assertIsInstance(data, pygHeteroData) self.assertGreater(data["v2"]["x"].shape[0], 0) self.assertEqual(data["v2"]["x"].shape[0], data["v2"]["is_seed"].shape[0]) + batch_sizes.append(int(data["v2"]["is_seed"].sum())) + self.assertGreater(data["v1"]["x"].shape[0], 0) + self.assertEqual(data["v1"]["x"].shape[0], data["v1"]["is_seed"].shape[0]) + self.assertGreater(data["v0"]["x"].shape[0], 0) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["y"].shape[0]) + self.assertEqual( + data["v0"]["x"].shape[0], data["v0"]["train_mask"].shape[0] + ) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["test_mask"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["is_seed"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["val_mask"].shape[0]) + self.assertTrue( + data['v2', 'v2v0', 'v0']["edge_index"].shape[1] > 0 + and data['v2', 'v2v0', 'v0']["edge_index"].shape[1] <= 943 + ) + self.assertTrue( + data['v2', 'v2v1', 'v1']["edge_index"].shape[1] > 0 + and data['v2', 'v2v1', 'v1']["edge_index"].shape[1] <= 959 + ) + self.assertTrue( + data['v2', 'v2v2', 'v2']["edge_index"].shape[1] > 0 + and data['v2', 'v2v2', 'v2']["edge_index"].shape[1] <= 966 + ) + self.assertTrue( + data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0 + and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710 + ) + self.assertTrue( + data['v1', 'v1v1', 'v1']["edge_index"].shape[1] > 0 + and data['v1', 'v1v1', 'v1']["edge_index"].shape[1] <= 1044 + ) + self.assertTrue( + data['v1', 'v1v2', 'v2']["edge_index"].shape[1] > 0 + and data['v1', 'v1v2', 'v2']["edge_index"].shape[1] <= 1038 + ) num_batches += 1 - self.assertEqual(num_batches, 6) + self.assertEqual(num_batches, 7) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_fetch(self): loader = HGTLoader( @@ -155,29 +180,192 @@ def test_fetch(self): v_out_labels={"v0": ["y"]}, v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, batch_size=16, - num_hops=1, + num_hops=2, shuffle=False, output_format="PyG", - add_self_loop=False, - loader_id=None, - buffer_size=4, ) data = loader.fetch( - [{"primary_id": "13", "type": "v2"}, {"primary_id": "28", "type": "v2"}] + [{"primary_id": "10", "type": "v0"}, {"primary_id": "55", "type": "v0"}] ) - self.assertIn("13", data["v2"]["primary_id"]) - self.assertIn("28", data["v2"]["primary_id"]) - for i, d in enumerate(data["v2"]["primary_id"]): - if d == "13" or d == "28": - self.assertTrue(data["v2"]["is_seed"][i].item()) + self.assertIn("primary_id", data["v0"]) + self.assertGreater(data["v0"]["x"].shape[0], 2) + self.assertGreater(data["v0v0"]["edge_index"].shape[1], 0) + self.assertIn("10", data["v0"]["primary_id"]) + self.assertIn("55", data["v0"]["primary_id"]) + for i, d in enumerate(data["v0"]["primary_id"]): + if d == "10" or d == "55": + self.assertTrue(data["v0"]["is_seed"][i].item()) else: - self.assertFalse(data["v2"]["is_seed"][i].item()) - # self.assertGreaterEqual(len(data["v0"]["primary_id"]), 2) - # self.assertGreaterEqual(len(data["v1"]["primary_id"]), 2) - # print("v0", data["v0"]["primary_id"]) - # print("v1", data["v1"]["primary_id"]) - # print("v2", data["v2"]["primary_id"]) - # print(data) + self.assertFalse(data["v0"]["is_seed"][i].item()) + + +class TestGDSHGTLoaderKafka(unittest.TestCase): + 
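+    # Same checks as the REST class above, but the per-seed vertex and edge
+    # messages are streamed through Kafka (kafka_address is assumed to point
+    # at a reachable broker, kafka:9092 here) and reassembled client-side
+    # into batches of batch_size seed vertices.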
@classmethod + def setUpClass(cls): + cls.conn = make_connection(graphname="hetero") + + def test_init(self): + loader = HGTLoader( + graph=self.conn, + num_neighbors={"v0": 3, "v1": 5, "v2": 10}, + v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + batch_size=16, + num_hops=2, + shuffle=True, + output_format="PyG", + kafka_address="kafka:9092" + ) + self.assertTrue(is_query_installed(self.conn, loader.query_name)) + self.assertIsNone(loader.num_batches) + + def test_whole_graph_df(self): + loader = HGTLoader( + graph=self.conn, + num_neighbors={"v0": 3, "v1": 5, "v2": 10}, + v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + num_batches=1, + num_hops=2, + shuffle=False, + output_format="dataframe", + kafka_address="kafka:9092" + ) + data = loader.data + self.assertTupleEqual(data[0]["v0"].shape, (152, 7)) + self.assertTupleEqual(data[0]["v1"].shape, (220, 3)) + self.assertTupleEqual(data[0]["v2"].shape, (200, 3)) + self.assertTrue( + data[1]["v0v0"].shape[0] > 0 and data[1]["v0v0"].shape[0] <= 710 + ) + self.assertTrue( + data[1]["v1v1"].shape[0] > 0 and data[1]["v1v1"].shape[0] <= 1044 + ) + self.assertTrue( + data[1]["v1v2"].shape[0] > 0 and data[1]["v1v2"].shape[0] <= 1038 + ) + self.assertTrue( + data[1]["v2v0"].shape[0] > 0 and data[1]["v2v0"].shape[0] <= 943 + ) + self.assertTrue( + data[1]["v2v1"].shape[0] > 0 and data[1]["v2v1"].shape[0] <= 959 + ) + self.assertTrue( + data[1]["v2v2"].shape[0] > 0 and data[1]["v2v2"].shape[0] <= 966 + ) + + def test_whole_graph_pyg(self): + loader = HGTLoader( + graph=self.conn, + num_neighbors={"v0": 3, "v1": 5, "v2": 10}, + v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", "val_mask", "test_mask"]}, + num_batches=1, + num_hops=2, + shuffle=False, + output_format="PyG", + kafka_address="kafka:9092" + ) + data = loader.data + self.assertTupleEqual(data["v0"]["x"].shape, (76, 77)) + self.assertEqual(data["v0"]["y"].shape[0], 76) + self.assertEqual(data["v0"]["train_mask"].shape[0], 76) + self.assertEqual(data["v0"]["test_mask"].shape[0], 76) + self.assertEqual(data["v0"]["val_mask"].shape[0], 76) + self.assertEqual(data["v0"]["is_seed"].shape[0], 76) + self.assertTupleEqual(data["v1"]["x"].shape, (110, 57)) + self.assertEqual(data["v1"]["is_seed"].shape[0], 110) + self.assertTupleEqual(data["v2"]["x"].shape, (100, 48)) + self.assertEqual(data["v2"]["is_seed"].shape[0], 100) + self.assertTrue( + data["v0v0"]["edge_index"].shape[1] > 0 + and data["v0v0"]["edge_index"].shape[1] <= 710 + ) + self.assertTrue( + data["v1v1"]["edge_index"].shape[1] > 0 + and data["v1v1"]["edge_index"].shape[1] <= 1044 + ) + self.assertTrue( + data["v1v2"]["edge_index"].shape[1] > 0 + and data["v1v2"]["edge_index"].shape[1] <= 1038 + ) + self.assertTrue( + data["v2v0"]["edge_index"].shape[1] > 0 + and data["v2v0"]["edge_index"].shape[1] <= 943 + ) + self.assertTrue( + data["v2v1"]["edge_index"].shape[1] > 0 + and data["v2v1"]["edge_index"].shape[1] <= 959 + ) + self.assertTrue( + data["v2v2"]["edge_index"].shape[1] > 0 + and data["v2v2"]["edge_index"].shape[1] <= 966 + ) + + def test_iterate_pyg(self): + loader = HGTLoader( + graph=self.conn, + num_neighbors={"v0": 2, "v1": 2, "v2": 2}, + v_in_feats={"v0": ["x"], "v1": ["x"], "v2": ["x"]}, + v_out_labels={"v0": ["y"]}, + v_extra_feats={"v0": ["train_mask", 
"val_mask", "test_mask"]}, + v_seed_types=["v2"], + batch_size=16, + num_hops=2, + shuffle=False, + output_format="PyG", + kafka_address="kafka:9092" + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + self.assertIsInstance(data, pygHeteroData) + self.assertGreater(data["v2"]["x"].shape[0], 0) + self.assertEqual(data["v2"]["x"].shape[0], data["v2"]["is_seed"].shape[0]) + batch_sizes.append(int(data["v2"]["is_seed"].sum())) + self.assertGreater(data["v1"]["x"].shape[0], 0) + self.assertEqual(data["v1"]["x"].shape[0], data["v1"]["is_seed"].shape[0]) + self.assertGreater(data["v0"]["x"].shape[0], 0) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["y"].shape[0]) + self.assertEqual( + data["v0"]["x"].shape[0], data["v0"]["train_mask"].shape[0] + ) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["test_mask"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["is_seed"].shape[0]) + self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["val_mask"].shape[0]) + self.assertTrue( + data['v2', 'v2v0', 'v0']["edge_index"].shape[1] > 0 + and data['v2', 'v2v0', 'v0']["edge_index"].shape[1] <= 943 + ) + self.assertTrue( + data['v2', 'v2v1', 'v1']["edge_index"].shape[1] > 0 + and data['v2', 'v2v1', 'v1']["edge_index"].shape[1] <= 959 + ) + self.assertTrue( + data['v2', 'v2v2', 'v2']["edge_index"].shape[1] > 0 + and data['v2', 'v2v2', 'v2']["edge_index"].shape[1] <= 966 + ) + self.assertTrue( + data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0 + and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710 + ) + self.assertTrue( + data['v1', 'v1v1', 'v1']["edge_index"].shape[1] > 0 + and data['v1', 'v1v1', 'v1']["edge_index"].shape[1] <= 1044 + ) + self.assertTrue( + data['v1', 'v1v2', 'v2']["edge_index"].shape[1] > 0 + and data['v1', 'v1v2', 'v2']["edge_index"].shape[1] <= 1038 + ) + num_batches += 1 + self.assertEqual(num_batches, 7) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) if __name__ == "__main__": @@ -187,6 +375,10 @@ def test_fetch(self): suite.addTest(TestGDSHGTLoaderREST("test_whole_graph_pyg")) suite.addTest(TestGDSHGTLoaderREST("test_iterate_pyg")) suite.addTest(TestGDSHGTLoaderREST("test_fetch")) + suite.addTest(TestGDSHGTLoaderKafka("test_init")) + suite.addTest(TestGDSHGTLoaderKafka("test_whole_graph_df")) + suite.addTest(TestGDSHGTLoaderKafka("test_whole_graph_pyg")) + suite.addTest(TestGDSHGTLoaderKafka("test_iterate_pyg")) runner = unittest.TextTestRunner(verbosity=2, failfast=True) runner.run(suite) diff --git a/tests/test_gds_NeighborLoader.py b/tests/test_gds_NeighborLoader.py index fa509a91..2a3f4ac8 100644 --- a/tests/test_gds_NeighborLoader.py +++ b/tests/test_gds_NeighborLoader.py @@ -586,7 +586,7 @@ def test_iterate_pyg(self): if "v1" in data.node_types: self.assertGreater(data["v1"]["x"].shape[0], 0) self.assertEqual(data["v1"]["x"].shape[0], data["v1"]["is_seed"].shape[0]) - if "v2" in data.node_types: + if "v0" in data.node_types: self.assertGreater(data["v0"]["x"].shape[0], 0) self.assertEqual(data["v0"]["x"].shape[0], data["v0"]["y"].shape[0]) self.assertEqual( From 001e1ea6a91a861314056e2f094298ca7c2041d4 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Thu, 9 Nov 2023 15:27:20 -0800 Subject: [PATCH 23/36] feat(gds): update nodepieceloader --- pyTigerGraph/gds/dataloaders.py | 163 ++++++++------ .../gsql/dataloaders/nodepiece_loader.gsql | 211 ++++++++---------- tests/test_gds_NodePieceLoader.py | 171 ++++++++++---- 3 files changed, 315 insertions(+), 
230 deletions(-) diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index 9dd0735f..ae5f78ed 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -703,8 +703,8 @@ def _read_graph_data( "Spektral is not installed. Please install it to use spektral output." ) # Get raw data from queue and parse - vertex_buffer = set() - edge_buffer = set() + vertex_buffer = [] + edge_buffer = [] buffer_size = 0 is_empty = False last_batch = False @@ -718,14 +718,16 @@ def _read_graph_data( if buffer_size > 0: last_batch = True else: - vertex_buffer.update(raw[0].splitlines()) - edge_buffer.update(raw[1].splitlines()) + vertex_buffer.extend(raw[0].splitlines()) + edge_buffer.extend(raw[1].splitlines()) buffer_size += 1 if (buffer_size < batch_size) and (not last_batch): continue try: + vertex_buffer_d = dict.fromkeys(vertex_buffer) + edge_buffer_d = dict.fromkeys(edge_buffer) data = BaseLoader._parse_graph_data_to_df( - raw = (vertex_buffer, edge_buffer), + raw = (vertex_buffer_d.keys(), edge_buffer_d.keys()), v_in_feats = v_in_feats, v_out_labels = v_out_labels, v_extra_feats = v_extra_feats, @@ -4084,7 +4086,12 @@ def __init__( else: self._seed_types = self._vtypes self._target_v_types = self._vtypes + if batch_size: + # batch size takes precedence over number of batches + self.batch_size = batch_size + self.num_batches = None + else: if not filter_by: num_vertices = sum(self._graph.getVertexCount(self._seed_types).values()) elif isinstance(filter_by, str): @@ -4100,12 +4107,9 @@ def __init__( ) else: raise ValueError("filter_by should be None, attribute name, or dict of {type name: attribute name}.") - self.num_batches = math.ceil(num_vertices / batch_size) - else: - # Otherwise, take the number of batches as is. + self.batch_size = math.ceil(num_vertices / num_batches) self.num_batches = num_batches self.filter_by = filter_by - self._payload["num_batches"] = self.num_batches if filter_by: if isinstance(filter_by, str): self._payload["filter_by"] = filter_by @@ -4113,8 +4117,7 @@ def __init__( attr = set(filter_by.values()) if len(attr) != 1: raise NotImplementedError("Filtering by different attributes for different vertex types is not supported. 
Please use the same attribute for different types.") - if batch_size: - self._payload["batch_size"] = batch_size + self._payload["filter_by"] = attr.pop() self._payload["shuffle"] = shuffle self._payload["v_types"] = self._vtypes self._payload["seed_types"] = self._seed_types @@ -4126,7 +4129,6 @@ def __init__( self._payload["clear_cache"] = clear_cache self._payload["delimiter"] = delimiter self._payload["input_vertices"] = [] - self._payload["num_heap_inserts"] = self.num_heap_inserts self._payload["num_edge_batches"] = self.num_edge_batches if e_types: self._payload["e_types"] = e_types @@ -4141,7 +4143,7 @@ def __init__( for v_type in self._vtypes: if anchor_attribute not in self._v_schema[v_type].keys(): to_change.append(v_type) - if to_change != []: + if to_change: print("Adding anchor attribute") ret = add_attribute(self._graph, "VERTEX", "BOOL", anchor_attribute, to_change, global_change=global_schema_change) print(ret) @@ -4152,7 +4154,7 @@ def __init__( if anchor_cache_attr not in self._v_schema[v_type].keys(): # add anchor cache attribute to_change.append(v_type) - if to_change != []: + if to_change: print("Adding anchor cache attribute") ret = add_attribute(self._graph, "VERTEX", "MAP", anchor_cache_attr, to_change, global_change=global_schema_change) print(ret) @@ -4234,34 +4236,49 @@ def _install_query(self, force: bool = False) -> str: if isinstance(self.attributes, dict): # Multiple vertex types + print_query_kafka = "" + print_query_http = "" print_query = "" for idx, vtype in enumerate(self._seed_types): v_attr_names = self.attributes.get(vtype, []) query_suffix.extend(v_attr_names) v_attr_types = self._v_schema[vtype] - if v_attr_names: - print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + s.@rel_context_set + delimiter + s.@ancs + delimiter + {} + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype, print_attr) - else: - print_query += '{} s.type == "{}" THEN \n @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + s.@rel_context_set + delimiter + s.@ancs + "\\n")\n'.format( - "IF" if idx==0 else "ELSE IF", vtype) - print_query += "END" - query_replace["{VERTEXATTRS}"] = print_query + print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) + print_query_http += """ + {} s.type == "{}" THEN + @@v_batch += (s.type + delimiter + stringify(getvid(s)) + delimiter + s.@rel_context_set + delimiter + s.@ancs {}+ "\\n")"""\ + .format("IF" if idx==0 else "ELSE IF", vtype, + "+ delimiter + " + print_attr if v_attr_names else "") + print_query_kafka += """ + {} s.type == "{}" THEN + STRING msg = (s.type + delimiter + stringify(getvid(s)) + delimiter + s.@rel_context_set + delimiter + s.@ancs {}+ "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "vertex_" + stringify(getvid(s)), msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending data for vertex " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\\n") + END""".format("IF" if idx==0 else "ELSE IF", vtype, + "+ delimiter + " + print_attr if v_attr_names else "") + print_query_http += """ + END""" + print_query_kafka += """ + END""" query_suffix = list(dict.fromkeys(query_suffix)) else: # Ignore vertex types v_attr_names = self.attributes query_suffix.extend(v_attr_names) v_attr_types = next(iter(self._v_schema.values())) - if v_attr_names: - print_attr = 
self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query = '@@v_batch += (stringify(getvid(s)) + delimiter + s.@rel_context_set + delimiter + s.@ancs + delimiter + {} + "\\n")'.format( - print_attr - ) - else: - print_query = '@@v_batch += (stringify(getvid(s)) + delimiter + s.@rel_context_set + delimiter + s.@ancs + "\\n")' - query_replace["{VERTEXATTRS}"] = print_query + print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) + print_query_http = '@@v_batch += (stringify(getvid(s)) + delimiter + s.@rel_context_set + delimiter + s.@ancs {}+ "\\n")'.format( + "+ delimiter + " + print_attr if v_attr_names else "" + ) + print_query_kafka = """ + STRING msg = (stringify(getvid(s)) + delimiter + s.@rel_context_set + delimiter + s.@ancs {}+ "\\n"), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, getvid(s)%kafka_topic_partitions, "vertex_" + stringify(getvid(s)), msg), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending data for vertex " + stringify(getvid(s)) + ": "+ stringify(kafka_errcode) + "\\n") + END""".format("+ delimiter + " + print_attr if v_attr_names else "") + query_replace["{VERTEXATTRSHTTP}"] = print_query_http + query_replace["{VERTEXATTRSKAFKA}"] = print_query_kafka md5 = hashlib.md5() query_suffix.extend([self.distributed_query]) md5.update(json.dumps(query_suffix).encode()) @@ -4305,6 +4322,7 @@ def processRelContext(row): context = [self.idToIdx[str(x)] for x in context][:self._payload["max_rel_context"]] context = context + [self.idToIdx["PAD"] for x in range(len(context), self._payload["max_rel_context"])] return context + def processAnchors(row): try: ancs = row.split(" ")[:-1] @@ -4319,6 +4337,7 @@ def processAnchors(row): dists += [self.idToIdx["PAD"] for x in range(len(dists), self._payload["max_anchors"])] toks += [self.idToIdx["PAD"] for x in range(len(toks), self._payload["max_anchors"])] return {"ancs":toks, "dists": dists} + if self.is_hetero: for v_type in data.keys(): data[v_type]["relational_context"] = data[v_type]["relational_context"].apply(lambda x: processRelContext(x)) @@ -4341,11 +4360,11 @@ def processAnchors(row): def _start(self) -> None: # Create task and result queues - self._read_task_q = Queue(self.buffer_size * 2) + self._read_task_q = Queue(self.buffer_size) self._data_q = Queue(self.buffer_size) self._exit_event = Event() - self._start_request(False, "vertex") + self._start_request(False) # Start reading thread. 
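+        # Vertex-only reader: each message carries a single delimiter-separated
+        # vertex line (data_batch), and _read_vertex_data, mirroring the graph
+        # reader's buffering, groups batch_size of them into one DataFrame that
+        # is then passed to nodepiece_process to map anchors and relational
+        # context onto token ids.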
if not self.is_hetero: @@ -4353,26 +4372,19 @@ def _start(self) -> None: else: v_attr_types = self._v_schema self._reader = Thread( - target=self._read_data, - args=( - self._exit_event, - self._read_task_q, - self._data_q, - "vertex", - self.output_format, - self.attributes, - {} if self.is_hetero else [], - {} if self.is_hetero else [], - v_attr_types, - [], - [], - [], - {}, - False, - self.delimiter, - False, - self.is_hetero, - self.nodepiece_process + target=self._read_vertex_data, + kwargs=dict( + exit_event = self._exit_event, + in_q = self._read_task_q, + out_q = self._data_q, + batch_size = self.batch_size, + v_in_feats = self.attributes, + v_out_labels = {} if self.is_hetero else [], + v_extra_feats = {} if self.is_hetero else [], + v_attr_types = v_attr_types, + delimiter = self.delimiter, + is_hetero = self.is_hetero, + callback_fn = self.nodepiece_process ), ) self._reader.start() @@ -4426,28 +4438,28 @@ def fetch(self, vertices: list) -> None: v_attr_types = next(iter(self._v_schema.values())) else: v_attr_types = self._v_schema - if self.is_hetero: - data = self._parse_data(resp[0]["vertex_batch"], - v_in_feats=attributes, - v_out_labels = {}, - v_extra_feats = {}, - v_attr_types=v_attr_types, - reindex=False, - delimiter = self.delimiter, - is_hetero=self.is_hetero, - primary_id=resp[0]["pids"], - callback_fn=self.nodepiece_process) + vertex_batch = set() + for i in resp: + if "pids" in i: + break + vertex_batch.add(i["data_batch"]) + data = BaseLoader._parse_vertex_data( + raw = vertex_batch, + v_in_feats = attributes, + v_out_labels = {} if self.is_hetero else [], + v_extra_feats = {} if self.is_hetero else [], + v_attr_types = v_attr_types, + delimiter = self.delimiter, + is_hetero = self.is_hetero + ) + if not self.is_hetero: + for column in data.columns: + data[column] = pd.to_numeric(data[column], errors="ignore") else: - data = self._parse_data(resp[0]["vertex_batch"], - v_in_feats=attributes, - v_out_labels = [], - v_extra_feats = [], - v_attr_types=v_attr_types, - reindex=False, - delimiter = self.delimiter, - is_hetero=self.is_hetero, - primary_id=resp[0]["pids"], - callback_fn=self.nodepiece_process) + for key in data: + for column in data[key].columns: + data[key][column] = pd.to_numeric(data[key][column], errors="ignore") + data = self.nodepiece_process(data) return data def precompute(self) -> None: @@ -4902,6 +4914,7 @@ def fetch(self, vertices: list) -> None: break vertex_batch.update(i["vertex_batch"].splitlines()) edge_batch.update(i["edge_batch"].splitlines()) + print(len(vertex_batch), len(edge_batch)) data = self._parse_graph_data_to_df( raw = (vertex_batch, edge_batch), v_in_feats = self.v_in_feats, @@ -4916,6 +4929,8 @@ def fetch(self, vertices: list) -> None: primary_id = i["pids"], is_hetero = self.is_hetero, ) + print(data[0]) + print(data[1]) if self.output_format == "dataframe" or self.output_format== "df": vertices, edges = data if not self.is_hetero: diff --git a/pyTigerGraph/gds/gsql/dataloaders/nodepiece_loader.gsql b/pyTigerGraph/gds/gsql/dataloaders/nodepiece_loader.gsql index 36f0ee1f..1c99da63 100644 --- a/pyTigerGraph/gds/gsql/dataloaders/nodepiece_loader.gsql +++ b/pyTigerGraph/gds/gsql/dataloaders/nodepiece_loader.gsql @@ -12,12 +12,11 @@ CREATE QUERY nodepiece_loader_{QUERYSUFFIX}( INT max_distance, INT max_anchors, INT max_rel_context, - INT batch_size, - INT num_batches=1, BOOL shuffle=FALSE, STRING delimiter, + INT num_chunks=2, STRING kafka_address="", - STRING kafka_topic, + STRING kafka_topic="", INT kafka_topic_partitions=1, 
STRING kafka_max_size="104857600", INT kafka_timeout=300000, @@ -33,50 +32,31 @@ CREATE QUERY nodepiece_loader_{QUERYSUFFIX}( STRING sasl_kerberos_service_name="", STRING sasl_kerberos_keytab="", STRING sasl_kerberos_principal="", - INT num_heap_inserts=10, INT num_edge_batches=10 ) SYNTAX v1{ TYPEDEF TUPLE Distance_Tuple; - TYPEDEF TUPLE ID_Tuple; INT num_vertices; - INT kafka_errcode; - INT batch_s; SumAccum @tmp_id; - SumAccum @@kafka_error; - SetAccum @next_pass, @to_pass, @received; HeapAccum (max_anchors, distance ASC) @token_heap; SumAccum @rel_context_set; SumAccum @ancs; - OrAccum @prev_sampled; OrAccum @heapFull; - MapAccum> @@token_count; MapAccum> @conv_map; BOOL cache_empty = FALSE; INT distance; - UINT producer; - - # Initialize Kafka producer - IF kafka_address != "" THEN - producer = init_kafka_producer( - kafka_address, kafka_max_size, security_protocol, - sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, - ssl_certificate_location, ssl_key_location, ssl_key_password, - ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, - sasl_kerberos_keytab, sasl_kerberos_principal); - END; start = {v_types}; # Perform fetch operation if desired IF clear_cache THEN - res = SELECT s FROM start:s POST-ACCUM s.{ANCHOR_CACHE_ATTRIBUTE} = s.@conv_map; + res = SELECT s FROM start:s POST-ACCUM s.{ANCHOR_CACHE_ATTRIBUTE} = s.@conv_map; END; IF input_vertices.size() != 0 AND NOT compute_all THEN - seeds = {input_vertices}; - res = SELECT s FROM seeds:s -(e_types)- v_types:t + seeds = {input_vertices}; + res = SELECT s FROM seeds:s -(e_types)- v_types:t ACCUM IF s.{ANCHOR_CACHE_ATTRIBUTE}.size() != 0 THEN - FOREACH (key, val) IN s.{ANCHOR_CACHE_ATTRIBUTE} DO # s.{ANCHOR_CACHE_ATTRIBUTE} should be changed to getAttr() when supported + FOREACH (key, val) IN s.{ANCHOR_CACHE_ATTRIBUTE} DO # s.ANCHOR_CACHE_ATTRIBUTE should be changed to getAttr() when supported s.@token_heap += Distance_Tuple(key, val) END ELSE @@ -91,6 +71,7 @@ CREATE QUERY nodepiece_loader_{QUERYSUFFIX}( ELSE cache_empty = TRUE; END; + IF cache_empty THEN # computing all, shuffle vertices if needed ancs = SELECT s FROM start:s @@ -127,107 +108,105 @@ CREATE QUERY nodepiece_loader_{QUERYSUFFIX}( END; END; END; - IF batch_size IS NULL THEN - batch_s = ceil(res.size()/num_batches); - ELSE - batch_s = batch_size; - END; - FOREACH batch_id IN RANGE[0, num_batches-1] DO - SumAccum @@v_batch; - SetAccum @@printed_vertices; - SetAccum @@seeds; - # Get batch seeds - IF input_vertices.size()==0 THEN + + # Get batch seeds + IF input_vertices.size()==0 THEN start = {seed_types}; - HeapAccum (1, tmp_id ASC) @@batch_heap; - @@batch_heap.resize(batch_s); - IF filter_by IS NOT NULL THEN - FOREACH iter IN RANGE[0,num_heap_inserts-1] DO - _verts = SELECT s FROM start:s - WHERE s.@tmp_id % num_heap_inserts == iter AND NOT s.@prev_sampled AND s.getAttr(filter_by, "BOOL") - POST-ACCUM @@batch_heap += ID_Tuple(s.@tmp_id, s); - END; - FOREACH elem IN @@batch_heap DO - @@seeds += elem.v; - END; - seeds = {@@seeds}; - seeds = SELECT s - FROM seeds:s - POST-ACCUM - s.@prev_sampled += TRUE, - @@printed_vertices += s; - ELSE - FOREACH iter IN RANGE[0,num_heap_inserts-1] DO - _verts = SELECT s FROM start:s - WHERE s.@tmp_id % num_heap_inserts == iter AND NOT s.@prev_sampled - POST-ACCUM @@batch_heap += ID_Tuple(s.@tmp_id, s); - END; - FOREACH elem IN @@batch_heap DO - @@seeds += elem.v; - END; - seeds = {@@seeds}; - seeds = SELECT s - FROM seeds:s - POST-ACCUM - s.@prev_sampled += TRUE, - @@printed_vertices += s; - END; - ELSE + # 
Filter seeds if needed + seeds = SELECT s + FROM start:s + WHERE filter_by is NULL OR s.getAttr(filter_by, "BOOL"); + ELSE start = input_vertices; seeds = SELECT s - FROM start:s - ACCUM @@printed_vertices += s; - END; - # Get relational context - - IF max_rel_context > 0 THEN + FROM start:s; + END; + + # Get relational context + IF max_rel_context > 0 THEN seeds = SELECT s FROM seeds:s -(e_types:e)- v_types:t - SAMPLE max_rel_context EDGE WHEN s.outdegree() >= max_rel_context - ACCUM s.@rel_context_set += e.type +" "; - END; + SAMPLE max_rel_context EDGE WHEN s.outdegree() >= max_rel_context + ACCUM s.@rel_context_set += e.type +" "; + END; - res = SELECT s FROM seeds:s - POST-ACCUM - FOREACH tup IN s.@token_heap DO + res = SELECT s + FROM seeds:s + POST-ACCUM + FOREACH tup IN s.@token_heap DO s.@ancs += stringify(tup.v_id)+":"+stringify(tup.distance)+" ", IF use_cache AND cache_empty THEN - s.@conv_map += (tup.v_id -> tup.distance) + s.@conv_map += (tup.v_id -> tup.distance) END - END, - IF (use_cache AND cache_empty) OR precompute THEN + END, + IF (use_cache AND cache_empty) OR precompute THEN s.{ANCHOR_CACHE_ATTRIBUTE} = s.@conv_map - END, - {VERTEXATTRS}; - IF NOT precompute THEN # No Output if precomputing - IF kafka_address != "" THEN - # Write to kafka - kafka_errcode = write_to_kafka(producer, kafka_topic, batch_id%kafka_topic_partitions, "vertex_batch_" + stringify(batch_id), @@v_batch); - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending vertex batch " + stringify(batch_id) + ": "+ stringify(kafka_errcode) + "\n"); - END; - ELSE # HTTP mode - # Add to response - IF input_vertices.size()==0 THEN - PRINT @@v_batch AS vertex_batch; - ELSE + END; + + IF NOT precompute THEN # No Output if precomputing + # If getting all vertices of given types + IF input_vertices.size()==0 THEN + IF kafka_address != "" THEN + SumAccum @@kafka_error; + # Initialize Kafka producer + UINT producer = init_kafka_producer( + kafka_address, kafka_max_size, security_protocol, + sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, + ssl_certificate_location, ssl_key_location, ssl_key_password, + ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, + sasl_kerberos_keytab, sasl_kerberos_principal); + + FOREACH chunk IN RANGE[0, num_chunks-1] DO + res = SELECT s FROM seeds:s + WHERE s.@tmp_id % num_chunks == chunk + POST-ACCUM + {VERTEXATTRSKAFKA} + LIMIT 1; + END; + + FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO + INT kafka_errcode = write_to_kafka(producer, kafka_topic, i, "STOP", ""); + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending STOP signal to topic partition " + stringify(i) + ": " + stringify(kafka_errcode) + "\n"); + END; + END; + + INT kafka_errcode = close_kafka_producer(producer, kafka_timeout); + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); + END; + + PRINT @@kafka_error as kafkaError; + ELSE # HTTP mode + FOREACH chunk IN RANGE[0, num_chunks-1] DO + ListAccum @@v_batch; + res = SELECT s + FROM seeds:s + WHERE s.@tmp_id % num_chunks == chunk + POST-ACCUM + {VERTEXATTRSHTTP} + LIMIT 1; + + FOREACH i IN @@v_batch DO + PRINT i as data_batch; + END; + END; + END; + ELSE # Else get given vertices + ListAccum @@v_batch; MapAccum @@id_map; MapAccum @@type_map; - source = @@printed_vertices; - res = - SELECT s - FROM source:s - POST-ACCUM @@id_map += (getvid(s) -> s), @@type_map += (getvid(s) -> s.type); - PRINT @@v_batch AS vertex_batch, @@id_map AS pids, @@type_map AS types; - 
END; - END; - END; - END; - - IF kafka_address != "" THEN - kafka_errcode = close_kafka_producer(producer, kafka_timeout); - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); + + res = SELECT s + FROM seeds:s + POST-ACCUM + {VERTEXATTRSHTTP}, + @@id_map += (getvid(s) -> s), + @@type_map += (getvid(s) -> s.type); + + FOREACH i IN @@v_batch DO + PRINT i as data_batch; + END; + PRINT @@id_map AS pids, @@type_map AS types; END; - PRINT @@kafka_error as kafkaError; END; } \ No newline at end of file diff --git a/tests/test_gds_NodePieceLoader.py b/tests/test_gds_NodePieceLoader.py index e34e7a84..8074f990 100644 --- a/tests/test_gds_NodePieceLoader.py +++ b/tests/test_gds_NodePieceLoader.py @@ -8,7 +8,7 @@ from pyTigerGraph.gds.utilities import is_query_installed -class TestGDSNodePieceLoader(unittest.TestCase): +class TestGDSNodePieceLoaderKafka(unittest.TestCase): @classmethod def setUpClass(cls): cls.conn = make_connection(graphname="Cora") @@ -18,16 +18,14 @@ def test_init(self): graph=self.conn, v_feats=["x", "y", "train_mask", "val_mask", "test_mask"], compute_anchors=True, - anchor_percentage=0.5, batch_size=16, shuffle=True, filter_by="train_mask", - loader_id=None, - buffer_size=4, - kafka_address="kafka:9092", + anchor_percentage=0.5, + kafka_address="kafka:9092" ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 9) + self.assertIsNone(loader.num_batches) def test_iterate(self): loader = NodePieceLoader( @@ -38,13 +36,12 @@ def test_iterate(self): shuffle=True, filter_by="train_mask", anchor_percentage=0.5, - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) num_batches = 0 + batch_sizes = [] for data in loader: - # print(num_batches, data.head()) + # print(num_batches, data.shape, data.head()) self.assertIsInstance(data, DataFrame) self.assertIn("x", data.columns) self.assertIn("y", data.columns) @@ -53,8 +50,12 @@ def test_iterate(self): self.assertIn("train_mask", data.columns) self.assertIn("val_mask", data.columns) self.assertIn("test_mask", data.columns) + batch_sizes.append(data.shape[0]) num_batches += 1 self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_all_vertices(self): loader = NodePieceLoader( @@ -64,8 +65,6 @@ def test_all_vertices(self): shuffle=True, filter_by="train_mask", anchor_percentage=0.5, - loader_id=None, - buffer_size=4, kafka_address="kafka:9092", ) data = loader.data @@ -78,6 +77,7 @@ def test_all_vertices(self): self.assertIn("train_mask", data.columns) self.assertIn("val_mask", data.columns) self.assertIn("test_mask", data.columns) + self.assertEqual(data.shape[0], 140) def test_sasl_plaintext(self): loader = NodePieceLoader( @@ -158,11 +158,9 @@ def test_init(self): shuffle=True, filter_by="train_mask", anchor_percentage=0.5, - loader_id=None, - buffer_size=4 ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 9) + self.assertIsNone(loader.num_batches) def test_iterate(self): loader = NodePieceLoader( @@ -173,12 +171,11 @@ def test_iterate(self): shuffle=True, filter_by="train_mask", anchor_percentage=0.5, - loader_id=None, - buffer_size=4 ) num_batches = 0 + batch_sizes = [] for data in loader: - # print(num_batches, data.head()) + # print(num_batches, data.shape, data.head()) self.assertIsInstance(data, DataFrame) self.assertIn("x", data.columns) self.assertIn("y", 
data.columns) @@ -187,8 +184,12 @@ def test_iterate(self): self.assertIn("train_mask", data.columns) self.assertIn("val_mask", data.columns) self.assertIn("test_mask", data.columns) + batch_sizes.append(data.shape[0]) num_batches += 1 self.assertEqual(num_batches, 9) + for i in batch_sizes[:-1]: + self.assertEqual(i, 16) + self.assertLessEqual(batch_sizes[-1], 16) def test_all_vertices(self): loader = NodePieceLoader( @@ -198,8 +199,6 @@ def test_all_vertices(self): shuffle=True, filter_by="train_mask", anchor_percentage=0.5, - loader_id=None, - buffer_size=4 ) data = loader.data # print(data) @@ -211,6 +210,7 @@ def test_all_vertices(self): self.assertIn("train_mask", data.columns) self.assertIn("val_mask", data.columns) self.assertIn("test_mask", data.columns) + self.assertEqual(data.shape[0], 140) class TestGDSHeteroNodePieceLoaderREST(unittest.TestCase): @@ -228,11 +228,9 @@ def test_init(self): batch_size=20, shuffle=True, filter_by=None, - loader_id=None, - buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 10) + self.assertIsNone(loader.num_batches) def test_iterate(self): loader = NodePieceLoader( @@ -244,23 +242,32 @@ def test_iterate(self): batch_size=20, shuffle=True, filter_by=None, - loader_id=None, - buffer_size=4, ) num_batches = 0 + batch_sizes = [] for data in loader: # print(num_batches, data) - self.assertIsInstance(data["v0"], DataFrame) - self.assertIsInstance(data["v1"], DataFrame) - self.assertIn("x", data["v0"].columns) - self.assertIn("relational_context", data["v0"].columns) - self.assertIn("anchors", data["v0"].columns) - self.assertIn("y", data["v0"].columns) - self.assertIn("x", data["v1"].columns) - self.assertIn("relational_context", data["v1"].columns) - self.assertIn("anchors", data["v1"].columns) + batchsize = 0 + self.assertTrue(("v0" in data) or ("v1" in data)) + if "v0" in data: + self.assertIsInstance(data["v0"], DataFrame) + self.assertIn("x", data["v0"].columns) + self.assertIn("relational_context", data["v0"].columns) + self.assertIn("anchors", data["v0"].columns) + self.assertIn("y", data["v0"].columns) + batchsize += data["v0"].shape[0] + if "v1" in data: + self.assertIsInstance(data["v1"], DataFrame) + self.assertIn("x", data["v1"].columns) + self.assertIn("relational_context", data["v1"].columns) + self.assertIn("anchors", data["v1"].columns) + batchsize += data["v1"].shape[0] num_batches += 1 + batch_sizes.append(batchsize) self.assertEqual(num_batches, 10) + for i in batch_sizes[:-1]: + self.assertEqual(i, 20) + self.assertLessEqual(batch_sizes[-1], 20) def test_all_vertices(self): loader = NodePieceLoader( @@ -272,8 +279,89 @@ def test_all_vertices(self): num_batches=1, shuffle=False, filter_by=None, - loader_id=None, - buffer_size=4, + ) + data = loader.data + # print(data) + self.assertIsInstance(data["v0"], DataFrame) + self.assertTupleEqual(data["v0"].shape, (76, 6)) + self.assertIsInstance(data["v1"], DataFrame) + self.assertIn("x", data["v0"].columns) + self.assertIn("y", data["v0"].columns) + self.assertIn("x", data["v1"].columns) + self.assertIn("anchors", data["v0"].columns) + self.assertIn("relational_context", data["v0"].columns) + self.assertIn("anchors", data["v1"].columns) + + +class TestGDSHeteroNodePieceLoaderKafka(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.conn = make_connection(graphname="hetero") + + def test_init(self): + loader = NodePieceLoader( + graph=self.conn, + compute_anchors=True, + anchor_percentage=0.5, + v_feats={"v0": 
["x", "y"], + "v1": ["x"]}, + batch_size=20, + shuffle=True, + filter_by=None, + kafka_address="kafka:9092", + ) + self.assertTrue(is_query_installed(self.conn, loader.query_name)) + self.assertIsNone(loader.num_batches) + + def test_iterate(self): + loader = NodePieceLoader( + compute_anchors=True, + anchor_percentage=0.5, + graph=self.conn, + v_feats={"v0": ["x", "y"], + "v1": ["x"]}, + batch_size=20, + shuffle=True, + filter_by=None, + kafka_address="kafka:9092", + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + batchsize = 0 + self.assertTrue(("v0" in data) or ("v1" in data)) + if "v0" in data: + self.assertIsInstance(data["v0"], DataFrame) + self.assertIn("x", data["v0"].columns) + self.assertIn("relational_context", data["v0"].columns) + self.assertIn("anchors", data["v0"].columns) + self.assertIn("y", data["v0"].columns) + batchsize += data["v0"].shape[0] + if "v1" in data: + self.assertIsInstance(data["v1"], DataFrame) + self.assertIn("x", data["v1"].columns) + self.assertIn("relational_context", data["v1"].columns) + self.assertIn("anchors", data["v1"].columns) + batchsize += data["v1"].shape[0] + num_batches += 1 + batch_sizes.append(batchsize) + self.assertEqual(num_batches, 10) + for i in batch_sizes[:-1]: + self.assertEqual(i, 20) + self.assertLessEqual(batch_sizes[-1], 20) + + def test_all_vertices(self): + loader = NodePieceLoader( + graph=self.conn, + compute_anchors=True, + anchor_percentage=0.5, + v_feats={"v0": ["x", "y"], + "v1": ["x"]}, + num_batches=1, + shuffle=False, + filter_by=None, + kafka_address="kafka:9092", ) data = loader.data # print(data) @@ -290,17 +378,20 @@ def test_all_vertices(self): if __name__ == "__main__": suite = unittest.TestSuite() - suite.addTest(TestGDSNodePieceLoader("test_init")) - suite.addTest(TestGDSNodePieceLoader("test_iterate")) - suite.addTest(TestGDSNodePieceLoader("test_all_vertices")) - #suite.addTest(TestGDSNodePieceLoader("test_sasl_plaintext")) - # suite.addTest(TestGDSNodePieceLoader("test_sasl_ssl")) + suite.addTest(TestGDSNodePieceLoaderKafka("test_init")) + suite.addTest(TestGDSNodePieceLoaderKafka("test_iterate")) + suite.addTest(TestGDSNodePieceLoaderKafka("test_all_vertices")) + #suite.addTest(TestGDSNodePieceLoaderKafka("test_sasl_plaintext")) + #suite.addTest(TestGDSNodePieceLoaderKafka("test_sasl_ssl")) suite.addTest(TestGDSNodePieceLoaderREST("test_init")) suite.addTest(TestGDSNodePieceLoaderREST("test_iterate")) suite.addTest(TestGDSNodePieceLoaderREST("test_all_vertices")) suite.addTest(TestGDSHeteroNodePieceLoaderREST("test_init")) suite.addTest(TestGDSHeteroNodePieceLoaderREST("test_iterate")) suite.addTest(TestGDSHeteroNodePieceLoaderREST("test_all_vertices")) + suite.addTest(TestGDSHeteroNodePieceLoaderKafka("test_init")) + suite.addTest(TestGDSHeteroNodePieceLoaderKafka("test_iterate")) + suite.addTest(TestGDSHeteroNodePieceLoaderKafka("test_all_vertices")) runner = unittest.TextTestRunner(verbosity=2, failfast=True) runner.run(suite) From 2ac345bb0f39657a11d912c2f8d533b6f227ef55 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Thu, 9 Nov 2023 15:28:10 -0800 Subject: [PATCH 24/36] test(BaseLoader): rm unneeded tests --- tests/test_gds_BaseLoader.py | 51 +++--------------------------------- 1 file changed, 4 insertions(+), 47 deletions(-) diff --git a/tests/test_gds_BaseLoader.py b/tests/test_gds_BaseLoader.py index 396ef453..621e6cf6 100644 --- a/tests/test_gds_BaseLoader.py +++ b/tests/test_gds_BaseLoader.py @@ -173,48 +173,6 @@ def test_read_vertex(self): ) 
assert_frame_equal(data, truth) - def test_read_vertex_shuffle(self): - read_task_q = Queue() - data_q = Queue(4) - exit_event = Event() - raw = ["99|1 0 0 1 |1|0|1\n", - "8|1 0 0 1 |1|1|1\n"] - for i in raw: - read_task_q.put(i) - thread = Thread( - target=self.loader._read_vertex_data, - kwargs=dict( - exit_event = exit_event, - in_q = read_task_q, - out_q = data_q, - batch_size = 2, - shuffle= True, - v_in_feats = ["x"], - v_out_labels = ["y"], - v_extra_feats = ["train_mask", "is_seed"], - v_attr_types = {"x": "INT", "y": "INT", "train_mask": "BOOL", "is_seed": "BOOL"}, - delimiter = "|" - ) - ) - thread.start() - data = data_q.get() - exit_event.set() - thread.join() - truth1 = pd.read_csv( - io.StringIO("".join(raw)), - header=None, - names=["vid", "x", "y", "train_mask", "is_seed"], - sep=self.loader.delimiter - ) - raw.reverse() - truth2 = pd.read_csv( - io.StringIO("".join(raw)), - header=None, - names=["vid", "x", "y", "train_mask", "is_seed"], - sep=self.loader.delimiter - ) - self.assertTrue((data==truth1).all().all() or (data==truth2).all().all()) - def test_read_vertex_callback(self): read_task_q = Queue() data_q = Queue(4) @@ -978,12 +936,11 @@ def test_read_bool_label(self): if __name__ == "__main__": suite = unittest.TestSuite() - # suite.addTest(TestGDSBaseLoader("test_get_schema")) - # suite.addTest(TestGDSBaseLoader("test_get_schema_no_primary_id_attr")) - # suite.addTest(TestGDSBaseLoader("test_validate_vertex_attributes")) - # suite.addTest(TestGDSBaseLoader("test_validate_edge_attributes")) + suite.addTest(TestGDSBaseLoader("test_get_schema")) + suite.addTest(TestGDSBaseLoader("test_get_schema_no_primary_id_attr")) + suite.addTest(TestGDSBaseLoader("test_validate_vertex_attributes")) + suite.addTest(TestGDSBaseLoader("test_validate_edge_attributes")) suite.addTest(TestGDSBaseLoader("test_read_vertex")) - suite.addTest(TestGDSBaseLoader("test_read_vertex_shuffle")) suite.addTest(TestGDSBaseLoader("test_read_vertex_callback")) suite.addTest(TestGDSBaseLoader("test_read_edge")) suite.addTest(TestGDSBaseLoader("test_read_edge_callback")) From 189bc29622b2b23bec661a5c02d816fdf9958dee Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Thu, 9 Nov 2023 15:55:35 -0800 Subject: [PATCH 25/36] fix(dataloaders): fix seeds type issue --- pyTigerGraph/gds/dataloaders.py | 57 +++++++++------------------------ 1 file changed, 15 insertions(+), 42 deletions(-) diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index ae5f78ed..9bd4c372 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -2148,7 +2148,8 @@ def _install_query(self, force: bool = False): if isinstance(self.v_in_feats, dict) or isinstance(self.e_in_feats, dict): # Multiple vertex types print_query_seed = "" - for idx, vtype in enumerate(self._seed_types): + print_query_other = "" + for idx, vtype in enumerate(self._vtypes): v_attr_names = ( self.v_in_feats.get(vtype, []) + self.v_out_labels.get(vtype, []) @@ -2162,26 +2163,17 @@ def _install_query(self, force: bool = False): .format("IF" if idx==0 else "ELSE IF", vtype, "+ delimiter + " + print_attr if v_attr_names else "") - print_query_seed += """ - END""" - query_replace["{SEEDVERTEXATTRS}"] = print_query_seed - print_query_other = "" - for idx, vtype in enumerate(self._vtypes): - v_attr_names = ( - self.v_in_feats.get(vtype, []) - + self.v_out_labels.get(vtype, []) - + self.v_extra_feats.get(vtype, []) - ) - v_attr_types = self._v_schema[vtype] - print_attr = self._generate_attribute_string("vertex", 
v_attr_names, v_attr_types) print_query_other += """ {} s.type == "{}" THEN @@v_batch += (s.type + delimiter + stringify(getvid(s)) {} + delimiter + "0\\n")"""\ .format("IF" if idx==0 else "ELSE IF", vtype, "+ delimiter + " + print_attr if v_attr_names else "") + print_query_seed += """ + END""" print_query_other += """ END""" + query_replace["{SEEDVERTEXATTRS}"] = print_query_seed query_replace["{OTHERVERTEXATTRS}"] = print_query_other # Multiple edge types print_query = "" @@ -3803,7 +3795,8 @@ def _install_query(self, force: bool = False): query_replace["{VERTEXATTRS}"] = print_query # Multiple edge types print_query_seed = "" - for idx, etype in enumerate(self._seed_types): + print_query_other = "" + for idx, etype in enumerate(self._etypes): e_attr_names = ( self.e_in_feats.get(etype, []) + self.e_out_labels.get(etype, []) @@ -3816,25 +3809,16 @@ def _install_query(self, force: bool = False): @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ delimiter + "1\\n")"""\ .format("IF" if idx==0 else "ELSE IF", etype, "+ delimiter + " + print_attr if e_attr_names else "") - print_query_seed += """ - END""" - query_replace["{SEEDEDGEATTRS}"] = print_query_seed - print_query_other = "" - for idx, etype in enumerate(self._etypes): - e_attr_names = ( - self.e_in_feats.get(etype, []) - + self.e_out_labels.get(etype, []) - + self.e_extra_feats.get(etype, []) - ) - e_attr_types = self._e_schema[etype] - print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) print_query_other += """ {} e.type == "{}" THEN @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ delimiter + "0\\n")"""\ .format("IF" if idx==0 else "ELSE IF", etype, "+ delimiter + "+ print_attr if e_attr_names else "") + print_query_seed += """ + END""" print_query_other += """ END""" + query_replace["{SEEDEDGEATTRS}"] = print_query_seed query_replace["{OTHEREDGEATTRS}"] = print_query_other else: # Ignore vertex types @@ -4713,7 +4697,8 @@ def _install_query(self, force: bool = False): if isinstance(self.v_in_feats, dict) or isinstance(self.e_in_feats, dict): # Multiple vertex types print_query_seed = "" - for idx, vtype in enumerate(self._seed_types): + print_query_other = "" + for idx, vtype in enumerate(self._vtypes): v_attr_names = ( self.v_in_feats.get(vtype, []) + self.v_out_labels.get(vtype, []) @@ -4727,26 +4712,17 @@ def _install_query(self, force: bool = False): .format("IF" if idx==0 else "ELSE IF", vtype, "+ delimiter + " + print_attr if v_attr_names else "") - print_query_seed += """ - END""" - query_replace["{SEEDVERTEXATTRS}"] = print_query_seed - print_query_other = "" - for idx, vtype in enumerate(self._vtypes): - v_attr_names = ( - self.v_in_feats.get(vtype, []) - + self.v_out_labels.get(vtype, []) - + self.v_extra_feats.get(vtype, []) - ) - v_attr_types = self._v_schema[vtype] - print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) print_query_other += """ {} s.type == "{}" THEN @@v_batch += (s.type + delimiter + stringify(getvid(s)) {} + delimiter + "0\\n")"""\ .format("IF" if idx==0 else "ELSE IF", vtype, "+ delimiter + " + print_attr if v_attr_names else "") + print_query_seed += """ + END""" print_query_other += """ END""" + query_replace["{SEEDVERTEXATTRS}"] = print_query_seed query_replace["{OTHERVERTEXATTRS}"] = print_query_other # Generate select for each type of neighbors print_select = "" @@ -4914,7 +4890,6 @@ def fetch(self, vertices: list) -> None: break 
vertex_batch.update(i["vertex_batch"].splitlines()) edge_batch.update(i["edge_batch"].splitlines()) - print(len(vertex_batch), len(edge_batch)) data = self._parse_graph_data_to_df( raw = (vertex_batch, edge_batch), v_in_feats = self.v_in_feats, @@ -4929,8 +4904,6 @@ def fetch(self, vertices: list) -> None: primary_id = i["pids"], is_hetero = self.is_hetero, ) - print(data[0]) - print(data[1]) if self.output_format == "dataframe" or self.output_format== "df": vertices, edges = data if not self.is_hetero: From 9a9e78c67718298d71414a316869992497502b94 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Fri, 10 Nov 2023 12:58:59 -0800 Subject: [PATCH 26/36] test(GDS): fix num_batches issue --- tests/test_gds_GDS.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_gds_GDS.py b/tests/test_gds_GDS.py index b247ac05..f6364f88 100644 --- a/tests/test_gds_GDS.py +++ b/tests/test_gds_GDS.py @@ -26,7 +26,7 @@ def test_neighborLoader(self): buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 9) + self.assertEqual(loader.batch_size, 16) def test_neighborLoader_multiple_filters(self): loaders = self.conn.gds.neighborLoader( @@ -60,7 +60,7 @@ def test_graphLoader(self): buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 11) + self.assertEqual(loader.batch_size, 1024) def test_vertexLoader(self): loader = self.conn.gds.vertexLoader( @@ -72,7 +72,7 @@ def test_vertexLoader(self): buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 9) + self.assertEqual(loader.batch_size, 16) def test_edgeLoader(self): loader = self.conn.gds.edgeLoader( @@ -83,7 +83,7 @@ def test_edgeLoader(self): buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 11) + self.assertEqual(loader.batch_size, 1024) def test_edgeNeighborLoader(self): loader = self.conn.gds.edgeNeighborLoader( @@ -100,7 +100,7 @@ def test_edgeNeighborLoader(self): buffer_size=4, ) self.assertTrue(is_query_installed(self.conn, loader.query_name)) - self.assertEqual(loader.num_batches, 11) + self.assertEqual(loader.batch_size, 1024) def test_configureKafka(self): self.conn.gds.configureKafka(kafka_address="kafka:9092") From 17390dc639fc9029f79f2ca02abcd0182a1295d2 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Fri, 10 Nov 2023 12:59:48 -0800 Subject: [PATCH 27/36] fix(Trainer): rm num_batches --- pyTigerGraph/gds/trainer.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/pyTigerGraph/gds/trainer.py b/pyTigerGraph/gds/trainer.py index f44ca992..58e56e87 100644 --- a/pyTigerGraph/gds/trainer.py +++ b/pyTigerGraph/gds/trainer.py @@ -16,6 +16,7 @@ import time import os import warnings +import math class BaseCallback(): """Base class for training callbacks. 
@@ -145,7 +146,7 @@ def on_train_step_end(self, trainer): trainer.update_train_step_metrics(metric.get_metrics()) metric.reset_metrics() trainer.update_train_step_metrics({"global_step": trainer.cur_step}) - trainer.update_train_step_metrics({"epoch": int(trainer.cur_step/trainer.train_loader.num_batches)}) + trainer.update_train_step_metrics({"epoch": trainer.cur_epoch}) def on_eval_start(self, trainer): """NO DOC""" @@ -407,12 +408,17 @@ def train(self, num_epochs=None, max_num_steps=None): Defaults to the length of the `training_dataloader` """ if num_epochs: - self.max_num_steps = self.train_loader.num_batches * num_epochs - else: + self.max_num_steps = math.inf + self.num_epochs = num_epochs + elif max_num_steps: self.max_num_steps = max_num_steps - self.num_epochs = num_epochs + self.num_epochs = math.inf + else: + self.max_num_steps = math.inf + self.num_epochs = 1 self.cur_step = 0 - while self.cur_step < self.max_num_steps: + self.cur_epoch = 0 + while self.cur_step < self.max_num_steps and self.cur_epoch < self.num_epochs: for callback in self.callbacks: callback.on_epoch_start(trainer=self) for batch in self.train_loader: @@ -432,7 +438,7 @@ def train(self, num_epochs=None, max_num_steps=None): self.cur_step += 1 for callback in self.callbacks: callback.on_train_step_end(trainer=self) - + self.cur_epoch += 1 for callback in self.callbacks: callback.on_epoch_end(trainer=self) From 0d03fe5911335a557edeeb16f5637e7c94817314 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Fri, 10 Nov 2023 13:24:31 -0800 Subject: [PATCH 28/36] test(featurizer): fix patch_ver issue --- tests/test_gds_featurizer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_gds_featurizer.py b/tests/test_gds_featurizer.py index b873d404..d6266c93 100644 --- a/tests/test_gds_featurizer.py +++ b/tests/test_gds_featurizer.py @@ -23,7 +23,11 @@ def test_get_db_version(self): major_ver, minor_ver, patch_ver = self.featurizer._get_db_version() self.assertIsNotNone(int(major_ver)) self.assertIsNotNone(int(minor_ver)) - self.assertIsNotNone(int(patch_ver)) + try: + patch_ver = int(patch_ver) + except: + pass + self.assertIsNotNone(patch_ver) self.assertIsInstance(self.featurizer.algo_ver, str) def test_get_algo_dict(self): From 22a02b69ac17256ad4ab1b37225f730158677675 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Fri, 17 Nov 2023 15:54:36 -0800 Subject: [PATCH 29/36] fix: bool error when num_batches is None --- pyTigerGraph/gds/trainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pyTigerGraph/gds/trainer.py b/pyTigerGraph/gds/trainer.py index 58e56e87..03e650c0 100644 --- a/pyTigerGraph/gds/trainer.py +++ b/pyTigerGraph/gds/trainer.py @@ -210,7 +210,7 @@ def on_epoch_start(self, trainer): self.epoch_bar = self.tqdm(desc="Epochs", total=trainer.num_epochs) else: self.epoch_bar = self.tqdm(desc="Training Steps", total=trainer.max_num_steps) - if not(self.batch_bar): + if self.batch_bar is None: self.batch_bar = self.tqdm(desc="Training Batches", total=trainer.train_loader.num_batches) def on_train_step_end(self, trainer): @@ -218,20 +218,20 @@ def on_train_step_end(self, trainer): logger = logging.getLogger(__name__) logger.info("train_step:"+str(trainer.get_train_step_metrics())) if self.tqdm: - if self.batch_bar: + if self.batch_bar is not None: self.batch_bar.update(1) def on_eval_start(self, trainer): """NO DOC""" trainer.reset_eval_metrics() if self.tqdm: - if not(self.valid_bar): + if self.valid_bar is None: self.valid_bar = self.tqdm(desc="Eval 
Batches", total=trainer.eval_loader.num_batches) def on_eval_step_end(self, trainer): """NO DOC""" if self.tqdm: - if self.valid_bar: + if self.valid_bar is not None: self.valid_bar.update(1) def on_eval_end(self, trainer): @@ -240,7 +240,7 @@ def on_eval_end(self, trainer): logger.info("evaluation:"+str(trainer.get_eval_metrics())) trainer.model.train() if self.tqdm: - if self.valid_bar: + if self.valid_bar is not None: self.valid_bar.close() self.valid_bar = None @@ -249,7 +249,7 @@ def on_epoch_end(self, trainer): if self.tqdm: if self.epoch_bar: self.epoch_bar.update(1) - if self.batch_bar: + if self.batch_bar is not None: self.batch_bar.close() self.batch_bar = None trainer.eval() From ed6a53cbaaadbac0b610f3ca0fcc873d1fc54d2c Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Wed, 13 Dec 2023 11:42:16 -0800 Subject: [PATCH 30/36] fix(test_HGTLoader): fix error when edge not present --- tests/test_gds_HGTLoader.py | 54 ++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/tests/test_gds_HGTLoader.py b/tests/test_gds_HGTLoader.py index cfa5f8f3..6278f8a8 100644 --- a/tests/test_gds_HGTLoader.py +++ b/tests/test_gds_HGTLoader.py @@ -154,18 +154,21 @@ def test_iterate_pyg(self): data['v2', 'v2v2', 'v2']["edge_index"].shape[1] > 0 and data['v2', 'v2v2', 'v2']["edge_index"].shape[1] <= 966 ) - self.assertTrue( - data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0 - and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710 - ) - self.assertTrue( - data['v1', 'v1v1', 'v1']["edge_index"].shape[1] > 0 - and data['v1', 'v1v1', 'v1']["edge_index"].shape[1] <= 1044 - ) - self.assertTrue( - data['v1', 'v1v2', 'v2']["edge_index"].shape[1] > 0 - and data['v1', 'v1v2', 'v2']["edge_index"].shape[1] <= 1038 - ) + if ('v0', 'v0v0', 'v0') in data.edge_types: + self.assertTrue( + data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0 + and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710 + ) + if ('v1', 'v1v1', 'v1') in data.edge_types: + self.assertTrue( + data['v1', 'v1v1', 'v1']["edge_index"].shape[1] > 0 + and data['v1', 'v1v1', 'v1']["edge_index"].shape[1] <= 1044 + ) + if ('v1', 'v1v2', 'v2') in data.edge_types: + self.assertTrue( + data['v1', 'v1v2', 'v2']["edge_index"].shape[1] > 0 + and data['v1', 'v1v2', 'v2']["edge_index"].shape[1] <= 1038 + ) num_batches += 1 self.assertEqual(num_batches, 7) for i in batch_sizes[:-1]: @@ -349,18 +352,21 @@ def test_iterate_pyg(self): data['v2', 'v2v2', 'v2']["edge_index"].shape[1] > 0 and data['v2', 'v2v2', 'v2']["edge_index"].shape[1] <= 966 ) - self.assertTrue( - data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0 - and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710 - ) - self.assertTrue( - data['v1', 'v1v1', 'v1']["edge_index"].shape[1] > 0 - and data['v1', 'v1v1', 'v1']["edge_index"].shape[1] <= 1044 - ) - self.assertTrue( - data['v1', 'v1v2', 'v2']["edge_index"].shape[1] > 0 - and data['v1', 'v1v2', 'v2']["edge_index"].shape[1] <= 1038 - ) + if ('v0', 'v0v0', 'v0') in data.edge_types: + self.assertTrue( + data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0 + and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710 + ) + if ('v1', 'v1v1', 'v1') in data.edge_types: + self.assertTrue( + data['v1', 'v1v1', 'v1']["edge_index"].shape[1] > 0 + and data['v1', 'v1v1', 'v1']["edge_index"].shape[1] <= 1044 + ) + if ('v1', 'v1v2', 'v2') in data.edge_types: + self.assertTrue( + data['v1', 'v1v2', 'v2']["edge_index"].shape[1] > 0 + and data['v1', 'v1v2', 'v2']["edge_index"].shape[1] <= 1038 + ) num_batches += 1 
self.assertEqual(num_batches, 7) for i in batch_sizes[:-1]: From 3f257a29bc262a75a38ae66f7bb7ab9e5f313d25 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Thu, 14 Dec 2023 11:23:27 -0800 Subject: [PATCH 31/36] fix(dataloaders): check edge number returned from DB --- pyTigerGraph/gds/dataloaders.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index 9bd4c372..3d2d1d25 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -2670,6 +2670,8 @@ def __init__( num_edges += tmp else: num_edges += 2*tmp + if num_edges==0: + raise ValueError("Cannot find any edge as seed. Please check your configuration and data. If they all look good, please use batch_size instead of num_batches or refresh metadata following https://docs.tigergraph.com/tigergraph-server/current/api/built-in-endpoints#_parameters_15") self.batch_size = math.ceil(num_edges / num_batches) self.num_batches = num_batches # Initialize the exporter @@ -3736,6 +3738,8 @@ def __init__( num_edges += tmp else: num_edges += 2*tmp + if num_edges==0: + raise ValueError("Cannot find any edge as seed. Please check the configuration and the data. If they all look right, please use batch_size instead of num_batches or refresh metadata following https://docs.tigergraph.com/tigergraph-server/current/api/built-in-endpoints#_parameters_15") self.batch_size = math.ceil(num_edges / num_batches) self.num_batches = num_batches # Initialize parameters for the query From 98cdc1f3055352e7e488e09c16ffb6949d551ae4 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Wed, 20 Dec 2023 16:24:32 -0800 Subject: [PATCH 32/36] feat: update baseloader and unit test for new queries --- pyTigerGraph/gds/dataloaders.py | 149 ++++++++++++++++++++------------ tests/test_gds_BaseLoader.py | 81 ++++++++++------- 2 files changed, 144 insertions(+), 86 deletions(-) diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index 3d2d1d25..7c2ac792 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -559,7 +559,7 @@ def _request_graph_rest( ) # Put raw data into reading queue. 
for i in resp: - read_task_q.put((i["vertex_batch"], i["edge_batch"])) + read_task_q.put((i["vertex_batch"], i["edge_batch"], i["seed"])) read_task_q.put(None) @staticmethod @@ -598,12 +598,13 @@ def _download_graph_kafka( if key == "STOP": read_task_q.put(None) empty = True - continue + break if key.startswith("vertex"): companion_key = key.replace("vertex", "edge") if companion_key in buffer: read_task_q.put((message.value.decode("utf-8"), - buffer[companion_key])) + buffer[companion_key], + key.split("_", 2)[-1])) del buffer[companion_key] else: buffer[key] = message.value.decode("utf-8") @@ -611,7 +612,8 @@ def _download_graph_kafka( companion_key = key.replace("edge", "vertex") if companion_key in buffer: read_task_q.put((buffer[companion_key], - message.value.decode("utf-8"))) + message.value.decode("utf-8"), + key.split("_", 2)[-1])) del buffer[companion_key] else: buffer[key] = message.value.decode("utf-8") @@ -619,6 +621,8 @@ def _download_graph_kafka( warnings.warn( "Unrecognized key {} for messages in kafka".format(key) ) + if empty: + break @staticmethod def _download_unimode_kafka( @@ -658,7 +662,8 @@ def _read_graph_data( add_self_loop: bool = False, delimiter: str = "|", is_hetero: bool = False, - callback_fn: Callable = None + callback_fn: Callable = None, + seed_type: str = "" ) -> NoReturn: # Import the right libraries based on output format out_format = out_format.lower() @@ -706,6 +711,7 @@ def _read_graph_data( vertex_buffer = [] edge_buffer = [] buffer_size = 0 + seeds = set() is_empty = False last_batch = False while (not exit_event.is_set()) and (not is_empty): @@ -718,16 +724,21 @@ def _read_graph_data( if buffer_size > 0: last_batch = True else: - vertex_buffer.extend(raw[0].splitlines()) - edge_buffer.extend(raw[1].splitlines()) + vertex_buffer.extend(raw[0].strip().split("\n ")) + edge_buffer.extend(raw[1].strip().split("\n ")) + seeds.add(raw[2]) buffer_size += 1 if (buffer_size < batch_size) and (not last_batch): continue try: vertex_buffer_d = dict.fromkeys(vertex_buffer) edge_buffer_d = dict.fromkeys(edge_buffer) + if seed_type: + raw_data = (vertex_buffer_d.keys(), edge_buffer_d.keys(), seeds) + else: + raw_data = (vertex_buffer_d.keys(), edge_buffer_d.keys()) data = BaseLoader._parse_graph_data_to_df( - raw = (vertex_buffer_d.keys(), edge_buffer_d.keys()), + raw = raw_data, v_in_feats = v_in_feats, v_out_labels = v_out_labels, v_extra_feats = v_extra_feats, @@ -739,6 +750,7 @@ def _read_graph_data( delimiter = delimiter, primary_id = {}, is_hetero = is_hetero, + seed_type = seed_type ) if out_format == "dataframe" or out_format == "df": vertices, edges = data @@ -816,6 +828,7 @@ def _read_graph_data( out_format, v_in_feats, v_out_labels, v_extra_feats, v_attr_types, e_in_feats, e_out_labels, e_extra_feats, e_attr_types, delimiter)) vertex_buffer.clear() edge_buffer.clear() + seeds.clear() buffer_size = 0 out_q.put(None) @@ -945,22 +958,37 @@ def _parse_vertex_data( v_extra_feats: Union[list, dict] = [], v_attr_types: dict = {}, delimiter: str = "|", - is_hetero: bool = False + is_hetero: bool = False, + seeds: list = [] ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]: """Parse raw vertex data into dataframes. 
""" # Read in vertex CSVs as dataframes # Each row is in format vid,v_in_feats,v_out_labels,v_extra_feats - # or vtype,vid,v_in_feats,v_out_labels,v_extra_feats - v_file = (line.strip().split(delimiter) for line in raw) + # or vtype,vid,v_in_feats,v_out_labels,v_extra_feats + v_file = (line.strip().split(delimiter) for line in raw if line) + # If seeds are given, create the is_seed column + if seeds: + seed_df = pd.DataFrame({ + "vid": list(seeds), + "is_seed": True + }) if not is_hetero: # String of vertices in format vid,v_in_feats,v_out_labels,v_extra_feats v_attributes = ["vid"] + v_in_feats + v_out_labels + v_extra_feats + if seeds: + try: + v_attributes.remove("is_seed") + except ValueError: + pass data = pd.DataFrame(v_file, columns=v_attributes, dtype="object") for v_attr in v_extra_feats: if v_attr_types.get(v_attr, "") == "MAP": # I am sorry that this is this ugly... data[v_attr] = data[v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) + if seeds: + data = data.merge(seed_df, on="vid", how="left") + data.fillna({"is_seed": False}, inplace=True) else: v_file_dict = defaultdict(list) for line in v_file: @@ -971,11 +999,19 @@ def _parse_vertex_data( v_in_feats.get(vtype, []) + \ v_out_labels.get(vtype, []) + \ v_extra_feats.get(vtype, []) + if seeds: + try: + v_attributes.remove("is_seed") + except ValueError: + pass data[vtype] = pd.DataFrame(v_file_dict[vtype], columns=v_attributes, dtype="object") for v_attr in v_extra_feats.get(vtype, []): if v_attr_types[vtype][v_attr] == "MAP": # I am sorry that this is this ugly... data[vtype][v_attr] = data[vtype][v_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) + if seeds: + data[vtype] = data[vtype].merge(seed_df, on="vid", how="left") + data[vtype].fillna({"is_seed": False}, inplace=True) return data @staticmethod @@ -986,36 +1022,68 @@ def _parse_edge_data( e_extra_feats: Union[list, dict] = [], e_attr_types: dict = {}, delimiter: str = "|", - is_hetero: bool = False + is_hetero: bool = False, + seeds: list = [] ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]: """Parse raw edge data into dataframes. """ # Read in edge CSVs as dataframes # Each row is in format source_vid,target_vid,e_in_feats,e_out_labels,e_extra_feats # or etype,source_vid,target_vid,e_in_feats,e_out_labels,e_extra_feats - e_file = (line.strip().split(delimiter) for line in raw) + e_file = (line.strip().split(delimiter) for line in raw if line) if not is_hetero: e_attributes = ["source", "target"] + e_in_feats + e_out_labels + e_extra_feats + if seeds: + try: + e_attributes.remove("is_seed") + except ValueError: + pass data = pd.DataFrame(e_file, columns=e_attributes, dtype="object") for e_attr in e_extra_feats: if e_attr_types.get(e_attr, "") == "MAP": # I am sorry that this is this ugly... 
data[e_attr] = data[e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) + # If seeds are given, create the is_seed column + if seeds: + seed_df = pd.DataFrame.from_records( + [i.split("_") for i in seeds], + columns=["source", "etype", "target"]) + seed_df["is_seed"] = True + del seed_df["etype"] + data = data.merge(seed_df, on=["source", "target"], how="left") + data.fillna({"is_seed": False}, inplace=True) else: e_file_dict = defaultdict(list) for line in e_file: e_file_dict[line[0]].append(line[1:]) data = {} + # If seeds are given, create the is_seed column + if seeds: + seed_df = pd.DataFrame.from_records( + [i.split("_") for i in seeds], + columns=["source", "etype", "target"]) + seed_df["is_seed"] = True for etype in e_file_dict: e_attributes = ["source", "target"] + \ e_in_feats.get(etype, []) + \ e_out_labels.get(etype, []) + \ e_extra_feats.get(etype, []) + if seeds: + try: + e_attributes.remove("is_seed") + except ValueError: + pass data[etype] = pd.DataFrame(e_file_dict[etype], columns=e_attributes, dtype="object") for e_attr in e_extra_feats.get(etype, []): if e_attr_types[etype][e_attr] == "MAP": # I am sorry that this is this ugly... data[etype][e_attr] = data[etype][e_attr].apply(lambda x: {y.split(",")[0].strip("("): y.split(",")[1].strip(")") for y in x.strip("[").strip("]").split(" ")[:-1]} if x != "[]" else {}) + if seeds: + tmp_df = seed_df[seed_df["etype"]==etype] + if len(tmp_df)>0: + data[etype] = data[etype].merge( + tmp_df[["source", "target", "is_seed"]], on=["source", "target"], how="left") + data[etype].fillna({"is_seed": False}, inplace=True) return data @staticmethod @@ -1031,13 +1099,18 @@ def _parse_graph_data_to_df( e_attr_types: dict = {}, delimiter: str = "|", primary_id: dict = {}, - is_hetero: bool = False + is_hetero: bool = False, + seed_type: str = "" ) -> Union[Tuple[pd.DataFrame, pd.DataFrame], Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]]: """Parse raw data into dataframes. """ # Read in vertex and edge CSVs as dataframes # A pair of in-memory CSVs (vertex, edge) - v_file, e_file = raw + if len(raw) == 3: + v_file, e_file, seed_file = raw + else: + v_file, e_file = raw + seed_file = [] vertices = BaseLoader._parse_vertex_data( raw = v_file, v_in_feats = v_in_feats, @@ -1045,7 +1118,9 @@ def _parse_graph_data_to_df( v_extra_feats = v_extra_feats, v_attr_types = v_attr_types, delimiter = delimiter, - is_hetero = is_hetero) + is_hetero = is_hetero, + seeds = seed_file if seed_type=="vertex" else [] + ) if primary_id: id_map = pd.DataFrame({"vid": primary_id.keys(), "primary_id": primary_id.values()}, dtype="object") @@ -1063,7 +1138,8 @@ def _parse_graph_data_to_df( e_extra_feats = e_extra_feats, e_attr_types = e_attr_types, delimiter = delimiter, - is_hetero = is_hetero + is_hetero = is_hetero, + seeds = seed_file if seed_type=="edge" else [] ) return (vertices, edges) @@ -1224,19 +1300,6 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, # Reformat as a graph. # Need to have a pair of tables for edges and vertices. 
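The one-line MAP parser that appears several times above unpacks map attributes exported as strings such as "[(k1,v1) (k2,v2) ]" into Python dicts. A standalone sketch of the same logic, using a made-up sample value that follows the bracket-and-trailing-space layout the lambda expects:

    def parse_map_attr(raw: str) -> dict:
        # Mirrors the inline lambda: "[(a,1) (b,2) ]" -> {"a": "1", "b": "2"}, "[]" -> {}
        if raw == "[]":
            return {}
        # Drop the surrounding brackets, split on spaces, ignore the trailing empty token.
        pairs = raw.strip("[").strip("]").split(" ")[:-1]
        return {p.split(",")[0].strip("("): p.split(",")[1].strip(")") for p in pairs}

    assert parse_map_attr("[(a,1) (b,2) ]") == {"a": "1", "b": "2"}
    assert parse_map_attr("[]") == {}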
vertices, edges = raw - # Dedupe vertices if there is is_seed column as the same vertex might be seed and non-seed at the same time - if not is_hetero: - if "is_seed" in vertices.columns: - seeds = set(vertices[vertices.is_seed.astype(int).astype(bool)]["vid"]) - vertices.loc[vertices.vid.isin(seeds), "is_seed"] = True - vertices.drop_duplicates(subset="vid", inplace=True, ignore_index=True) - else: - for vtype in vertices: - df = vertices[vtype] - if "is_seed" in df.columns: - seeds = set(df[df.is_seed.astype(int).astype(bool)]["vid"]) - df.loc[df.vid.isin(seeds), "is_seed"] = True - df.drop_duplicates(subset="vid", inplace=True, ignore_index=True) edgelist = BaseLoader._get_edgelist(vertices, edges, is_hetero, e_attr_types) if not is_hetero: # Deal with edgelist first @@ -1438,19 +1501,6 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, # Reformat as a graph. # Need to have a pair of tables for edges and vertices. vertices, edges = raw - # Dedupe vertices if there is is_seed column as the same vertex might be seed and non-seed at the same time - if not is_hetero: - if "is_seed" in vertices.columns: - seeds = set(vertices[vertices.is_seed.astype(int).astype(bool)]["vid"]) - vertices.loc[vertices.vid.isin(seeds), "is_seed"] = True - vertices.drop_duplicates(subset="vid", inplace=True, ignore_index=True) - else: - for vtype in vertices: - df = vertices[vtype] - if "is_seed" in df.columns: - seeds = set(df[df.is_seed.astype(int).astype(bool)]["vid"]) - df.loc[df.vid.isin(seeds), "is_seed"] = True - df.drop_duplicates(subset="vid", inplace=True, ignore_index=True) edgelist = BaseLoader._get_edgelist(vertices, edges, is_hetero, e_attr_types) if not is_hetero: # Deal with edgelist first @@ -1597,19 +1647,6 @@ def add_sep_attr(attr_names: list, attr_types: dict, attr_df: pd.DataFrame, # Reformat as a graph. # Need to have a pair of tables for edges and vertices. 
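The per-vertex deduplication removed here is superseded by the seed handling added earlier in this patch: seed membership is attached once, by a left merge against a small seed table, instead of being reconciled across duplicate rows. A minimal pandas sketch of that pattern, with made-up vertex ids and features:

    import pandas as pd

    vertices = pd.DataFrame({"vid": ["99", "8", "2"],
                             "x": ["1 0 0 1 ", "1 0 0 1 ", "0 1 1 0 "]})
    seed_df = pd.DataFrame({"vid": ["99"], "is_seed": True})

    # Left-merge the seed flags, then mark every vertex without a match as a non-seed.
    vertices = vertices.merge(seed_df, on="vid", how="left")
    vertices.fillna({"is_seed": False}, inplace=True)

    assert vertices["is_seed"].tolist() == [True, False, False]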
vertices, edges = raw - # Dedupe vertices if there is is_seed column as the same vertex might be seed and non-seed at the same time - if not is_hetero: - if "is_seed" in vertices.columns: - seeds = set(vertices[vertices.is_seed.astype(int).astype(bool)]["vid"]) - vertices.loc[vertices.vid.isin(seeds), "is_seed"] = True - vertices.drop_duplicates(subset="vid", inplace=True, ignore_index=True) - else: - for vtype in vertices: - df = vertices[vtype] - if "is_seed" in df.columns: - seeds = set(df[df.is_seed.astype(int).astype(bool)]["vid"]) - df.loc[df.vid.isin(seeds), "is_seed"] = True - df.drop_duplicates(subset="vid", inplace=True, ignore_index=True) edgelist = BaseLoader._get_edgelist(vertices, edges, is_hetero, e_attr_types) if not is_hetero: # Deal with edgelist first diff --git a/tests/test_gds_BaseLoader.py b/tests/test_gds_BaseLoader.py index 621e6cf6..68ad7bed 100644 --- a/tests/test_gds_BaseLoader.py +++ b/tests/test_gds_BaseLoader.py @@ -270,8 +270,9 @@ def test_read_graph_out_df(self): data_q = Queue(4) exit_event = Event() raw = ( - "99|1 0 0 1 |1|0|1\n8|1 0 0 1 |1|1|1\n", - "1|2|0.1|2021|1|0\n2|1|1.5|2020|0|1\n", + "99|1 0 0 1 |1|0\n 8|1 0 0 1 |1|1\n ", + "1|2|0.1|2021|1|0\n 2|1|1.5|2020|0|1\n ", + "99" ) read_task_q.put(raw) thread = Thread( @@ -289,7 +290,8 @@ def test_read_graph_out_df(self): e_out_labels = ["y"], e_extra_feats = ["is_train"], e_attr_types = {"x": "FLOAT", "time": "INT", "y": "INT", "is_train": "BOOL"}, - delimiter = "|" + delimiter = "|", + seed_type = "vertex" ) ) thread.start() @@ -299,9 +301,10 @@ def test_read_graph_out_df(self): vertices = pd.read_csv( io.StringIO(raw[0]), header=None, - names=["vid", "x", "y", "train_mask", "is_seed"], + names=["vid", "x", "y", "train_mask"], sep=self.loader.delimiter ) + vertices["is_seed"] = [True, False] edges = pd.read_csv( io.StringIO(raw[1]), header=None, @@ -316,8 +319,9 @@ def test_read_graph_out_df_callback(self): data_q = Queue(4) exit_event = Event() raw = ( - "99|1 0 0 1 |1|0|1\n8|1 0 0 1 |1|1|1\n", - "1|2|0.1|2021|1|0\n2|1|1.5|2020|0|1\n", + "99|1 0 0 1 |1|0|1\n 8|1 0 0 1 |1|1|1\n ", + "1|2|0.1|2021|1|0\n 2|1|1.5|2020|0|1\n ", + "" ) read_task_q.put(raw) thread = Thread( @@ -351,8 +355,9 @@ def test_read_graph_out_pyg(self): data_q = Queue(4) exit_event = Event() raw = ( - "99|1 0 0 1 |1|0|Alex|1\n8|1 0 0 1 |1|1|Bill|0\n", - "99|8|0.1|2021|1|0|a b \n8|99|1.5|2020|0|1|c d \n", + "99|1 0 0 1 |1|0|Alex\n 8|1 0 0 1 |1|1|Bill\n ", + "99|8|0.1|2021|1|0|a b \n 8|99|1.5|2020|0|1|c d \n ", + "99" ) read_task_q.put(raw) thread = Thread( @@ -378,7 +383,8 @@ def test_read_graph_out_pyg(self): e_out_labels = ["y"], e_extra_feats = ["is_train", "category"], e_attr_types = {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL", "category": "LIST:STRING"}, - delimiter = "|" + delimiter = "|", + seed_type = "vertex" ) ) thread.start() @@ -405,8 +411,9 @@ def test_read_graph_out_dgl(self): data_q = Queue(4) exit_event = Event() raw = ( - "99|1 0 0 1 |1|0|Alex|1\n8|1 0 0 1 |1|1|Bill|0\n", - "99|8|0.1|2021|1|0|a b \n8|99|1.5|2020|0|1|c d \n", + "99|1 0 0 1 |1|0|Alex\n 8|1 0 0 1 |1|1|Bill\n ", + "99|8|0.1|2021|1|0|a b \n 8|99|1.5|2020|0|1|c d \n ", + "99" ) read_task_q.put(raw) thread = Thread( @@ -432,7 +439,8 @@ def test_read_graph_out_dgl(self): e_out_labels = ["y"], e_extra_feats = ["is_train", "category"], e_attr_types = {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL", "category": "LIST:STRING"}, - delimiter = "|" + delimiter = "|", + seed_type = "vertex" ) ) thread.start() @@ -499,7 +507,7 @@ def 
test_read_graph_no_attr(self): read_task_q = Queue() data_q = Queue(4) exit_event = Event() - raw = ("99|1\n8|0\n", "99|8\n8|99\n") + raw = ("99\n 8\n ", "99|8\n 8|99\n ", "99") read_task_q.put(raw) thread = Thread( target=self.loader._read_graph_data, @@ -518,7 +526,8 @@ def test_read_graph_no_attr(self): "name": "STRING", "is_seed": "BOOL", }, - delimiter = "|" + delimiter = "|", + seed_type = "vertex" ) ) thread.start() @@ -534,8 +543,9 @@ def test_read_graph_no_edge(self): data_q = Queue(4) exit_event = Event() raw = ( - "99|1 0 0 1 |1|0|Alex|1\n8|1 0 0 1 |1|1|Bill|0\n", + "99|1 0 0 1 |1|0|Alex\n 8|1 0 0 1 |1|1|Bill\n ", "", + "99" ) read_task_q.put(raw) thread = Thread( @@ -561,7 +571,8 @@ def test_read_graph_no_edge(self): e_out_labels = ["y"], e_extra_feats = ["is_train"], e_attr_types = {"x": "DOUBLE", "time": "INT", "y": "INT", "is_train": "BOOL"}, - delimiter = "|" + delimiter = "|", + seed_type = "vertex" ) ) thread.start() @@ -584,8 +595,9 @@ def test_read_hetero_graph_out_pyg(self): data_q = Queue(4) exit_event = Event() raw = ( - "People|99|1 0 0 1 |1|0|Alex|1\nPeople|8|1 0 0 1 |1|1|Bill|0\nCompany|2|0.3|0\n", - "Colleague|99|8|0.1|2021|1|0\nColleague|8|99|1.5|2020|0|1\nWork|99|2\nWork|2|8\n", + "People|99|1 0 0 1 |1|0|Alex\n People|8|1 0 0 1 |1|1|Bill\n Company|2|0.3\n ", + "Colleague|99|8|0.1|2021|1|0\n Colleague|8|99|1.5|2020|0|1\n Work|99|2\n Work|2|8\n ", + "99" ) read_task_q.put(raw) thread = Thread( @@ -628,7 +640,8 @@ def test_read_hetero_graph_out_pyg(self): "IsDirected": False} }, delimiter = "|", - is_hetero = True + is_hetero = True, + seed_type = "vertex" ) ) thread.start() @@ -665,8 +678,9 @@ def test_read_hetero_graph_no_attr(self): data_q = Queue(4) exit_event = Event() raw = ( - "People|99|1\nPeople|8|0\nCompany|2|0\n", - "Colleague|99|8\nColleague|8|99\nWork|99|2\nWork|2|8\n", + "People|99\n People|8\n Company|2\n ", + "Colleague|99|8\n Colleague|8|99\n Work|99|2\n Work|2|8\n ", + "99" ) read_task_q.put(raw) thread = Thread( @@ -709,7 +723,8 @@ def test_read_hetero_graph_no_attr(self): "IsDirected": False} }, delimiter = "|", - is_hetero = True + is_hetero = True, + seed_type = "vertex" ) ) thread.start() @@ -731,8 +746,9 @@ def test_read_hetero_graph_no_edge(self): data_q = Queue(4) exit_event = Event() raw = ( - "People|99|1 0 0 1 |1|0|Alex|1\nPeople|8|1 0 0 1 |1|1|Bill|0\nCompany|2|0.3|0\n", + "People|99|1 0 0 1 |1|0|Alex\n People|8|1 0 0 1 |1|1|Bill\n Company|2|0.3\n ", "", + "99" ) read_task_q.put(raw) thread = Thread( @@ -775,7 +791,8 @@ def test_read_hetero_graph_no_edge(self): "IsDirected": False} }, delimiter = "|", - is_hetero = True + is_hetero = True, + seed_type = "vertex" ) ) thread.start() @@ -802,8 +819,9 @@ def test_read_hetero_graph_out_dgl(self): data_q = Queue(4) exit_event = Event() raw = ( - "People|99|1 0 0 1 |1|0|Alex|1\nPeople|8|1 0 0 1 |1|1|Bill|0\nCompany|2|0.3|0\n", - "Colleague|99|8|0.1|2021|1|0\nColleague|8|99|1.5|2020|0|1\nWork|99|2|a b \nWork|2|8|c d \n", + "People|99|1 0 0 1 |1|0|Alex\n People|8|1 0 0 1 |1|1|Bill\n Company|2|0.3\n ", + "Colleague|99|8|0.1|2021|1|0\n Colleague|8|99|1.5|2020|0|1\n Work|99|2|a b \n Work|2|8|c d \n ", + "99" ) read_task_q.put(raw) thread = Thread( @@ -847,7 +865,8 @@ def test_read_hetero_graph_out_dgl(self): "category": "LIST:STRING"} }, delimiter = "|", - is_hetero = True + is_hetero = True, + seed_type = "vertex" ) ) thread.start() @@ -885,8 +904,9 @@ def test_read_bool_label(self): data_q = Queue(4) exit_event = Event() raw = ( - "99|1 0 0 1 |1|0|Alex|1\n8|1 0 0 1 |1|1|Bill|0\n", - 
"99|8|0.1|2021|1|0\n8|99|1.5|2020|0|1\n", + "99|1 0 0 1 |1|0|Alex\n 8|1 0 0 1 |1|1|Bill\n ", + "99|8|0.1|2021|1|0\n 8|99|1.5|2020|0|1\n ", + "99" ) read_task_q.put(raw) thread = Thread( @@ -912,7 +932,8 @@ def test_read_bool_label(self): e_out_labels = ["y"], e_extra_feats = ["is_train"], e_attr_types = {"x": "DOUBLE", "time": "INT", "y": "BOOL", "is_train": "BOOL"}, - delimiter = "|" + delimiter = "|", + seed_type = "vertex" ) ) thread.start() From 7e0de753e6633c67b5be347f51cca6a27773665d Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Thu, 21 Dec 2023 16:05:54 -0800 Subject: [PATCH 33/36] feat(EdgeNeighborLoader): update loader and gsql --- pyTigerGraph/gds/dataloaders.py | 90 ++++++----- .../gds/gsql/dataloaders/edge_nei_loader.gsql | 141 +++++++++++------- 2 files changed, 138 insertions(+), 93 deletions(-) diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index 7c2ac792..f2a95745 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -708,8 +708,8 @@ def _read_graph_data( "Spektral is not installed. Please install it to use spektral output." ) # Get raw data from queue and parse - vertex_buffer = [] - edge_buffer = [] + vertex_buffer = dict() + edge_buffer = dict() buffer_size = 0 seeds = set() is_empty = False @@ -724,19 +724,17 @@ def _read_graph_data( if buffer_size > 0: last_batch = True else: - vertex_buffer.extend(raw[0].strip().split("\n ")) - edge_buffer.extend(raw[1].strip().split("\n ")) + vertex_buffer.update({i.strip():"" for i in raw[0].strip().splitlines()}) + edge_buffer.update({i.strip():"" for i in raw[1].strip().splitlines()}) seeds.add(raw[2]) buffer_size += 1 if (buffer_size < batch_size) and (not last_batch): continue try: - vertex_buffer_d = dict.fromkeys(vertex_buffer) - edge_buffer_d = dict.fromkeys(edge_buffer) if seed_type: - raw_data = (vertex_buffer_d.keys(), edge_buffer_d.keys(), seeds) + raw_data = (vertex_buffer.keys(), edge_buffer.keys(), seeds) else: - raw_data = (vertex_buffer_d.keys(), edge_buffer_d.keys()) + raw_data = (vertex_buffer.keys(), edge_buffer.keys()) data = BaseLoader._parse_graph_data_to_df( raw = raw_data, v_in_feats = v_in_feats, @@ -966,7 +964,7 @@ def _parse_vertex_data( # Read in vertex CSVs as dataframes # Each row is in format vid,v_in_feats,v_out_labels,v_extra_feats # or vtype,vid,v_in_feats,v_out_labels,v_extra_feats - v_file = (line.strip().split(delimiter) for line in raw if line) + v_file = (line.split(delimiter) for line in raw) # If seeds are given, create the is_seed column if seeds: seed_df = pd.DataFrame({ @@ -1030,7 +1028,7 @@ def _parse_edge_data( # Read in edge CSVs as dataframes # Each row is in format source_vid,target_vid,e_in_feats,e_out_labels,e_extra_feats # or etype,source_vid,target_vid,e_in_feats,e_out_labels,e_extra_feats - e_file = (line.strip().split(delimiter) for line in raw if line) + e_file = (line.split(delimiter) for line in raw) if not is_hetero: e_attributes = ["source", "target"] + e_in_feats + e_out_labels + e_extra_feats if seeds: @@ -1084,6 +1082,8 @@ def _parse_edge_data( data[etype] = data[etype].merge( tmp_df[["source", "target", "is_seed"]], on=["source", "target"], how="left") data[etype].fillna({"is_seed": False}, inplace=True) + else: + data[etype]["is_seed"] = False return data @staticmethod @@ -3817,7 +3817,8 @@ def _install_query(self, force: bool = False): if self.is_hetero: # Multiple vertex types - print_query = "" + print_query_seed = "" + print_query_other = "" for idx, vtype in enumerate(self._vtypes): v_attr_names = ( 
self.v_in_feats.get(vtype, []) @@ -3826,17 +3827,25 @@ def _install_query(self, force: bool = False): ) v_attr_types = self._v_schema[vtype] print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query += """ + print_query_seed += """ {} s.type == "{}" THEN - @@v_batch += (s.type + delimiter + stringify(getvid(s)) {}+ "\\n")"""\ + @@v_batch += (s->(s.type + delimiter + stringify(getvid(s)) {}+ "\\n"))"""\ .format("IF" if idx==0 else "ELSE IF", vtype, "+ delimiter + " + print_attr if v_attr_names else "") - print_query += """ + print_query_other += """ + {} s.type == "{}" THEN + @@v_batch += (tmp_seed->(s.type + delimiter + stringify(getvid(s)) {}+ "\\n"))"""\ + .format("IF" if idx==0 else "ELSE IF", vtype, + "+ delimiter + " + print_attr if v_attr_names else "") + print_query_seed += """ + END""" + print_query_other += """ END""" - query_replace["{VERTEXATTRS}"] = print_query + query_replace["{SEEDVERTEXATTRS}"] = print_query_seed + query_replace["{OTHERVERTEXATTRS}"] = print_query_other # Multiple edge types - print_query_seed = "" - print_query_other = "" + print_query = "" + print_query_kafka = "" for idx, etype in enumerate(self._etypes): e_attr_names = ( self.e_in_feats.get(etype, []) @@ -3845,43 +3854,49 @@ def _install_query(self, force: bool = False): ) e_attr_types = self._e_schema[etype] print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) - print_query_seed += """ + print_query += """ {} e.type == "{}" THEN - @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ delimiter + "1\\n")"""\ + @@e_batch += (tmp_seed->(e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ "\\n"))"""\ .format("IF" if idx==0 else "ELSE IF", etype, "+ delimiter + " + print_attr if e_attr_names else "") - print_query_other += """ + print_query_kafka += """ {} e.type == "{}" THEN - @@e_batch += (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ delimiter + "0\\n")"""\ + SET tmp_e = (e.type + delimiter + stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ "\\n", ""), + tmp_e_batch = tmp_e_batch UNION tmp_e"""\ .format("IF" if idx==0 else "ELSE IF", etype, - "+ delimiter + "+ print_attr if e_attr_names else "") - print_query_seed += """ + "+ delimiter + " + print_attr if e_attr_names else "") + print_query += """ END""" - print_query_other += """ + print_query_kafka += """ END""" - query_replace["{SEEDEDGEATTRS}"] = print_query_seed - query_replace["{OTHEREDGEATTRS}"] = print_query_other + query_replace["{EDGEATTRS}"] = print_query + query_replace["{EDGEATTRSKAFKA}"] = print_query_kafka else: # Ignore vertex types v_attr_names = self.v_in_feats + self.v_out_labels + self.v_extra_feats v_attr_types = next(iter(self._v_schema.values())) print_attr = self._generate_attribute_string("vertex", v_attr_names, v_attr_types) - print_query = '@@v_batch += (stringify(getvid(s)) {}+ "\\n")'.format( + print_query_seed = '@@v_batch += (s->(stringify(getvid(s)) {}+ "\\n"))'.format( "+ delimiter + " + print_attr if v_attr_names else "" ) - query_replace["{VERTEXATTRS}"] = print_query + print_query_other = '@@v_batch += (tmp_seed->(stringify(getvid(s)) {}+ "\\n"))'.format( + "+ delimiter + " + print_attr if v_attr_names else "" + ) + query_replace["{SEEDVERTEXATTRS}"] = print_query_seed + query_replace["{OTHERVERTEXATTRS}"] = print_query_other # Ignore edge types e_attr_names = self.e_in_feats + self.e_out_labels + self.e_extra_feats e_attr_types = 
next(iter(self._e_schema.values())) print_attr = self._generate_attribute_string("edge", e_attr_names, e_attr_types) - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ delimiter + "1\\n")'.format( + print_query = '@@e_batch += (tmp_seed->(stringify(getvid(s)) + delimiter + stringify(getvid(t)) {} + "\\n"))'.format( "+ delimiter + " + print_attr if e_attr_names else "" ) - query_replace["{SEEDEDGEATTRS}"] = print_query - print_query = '@@e_batch += (stringify(getvid(s)) + delimiter + stringify(getvid(t)) {}+ delimiter + "0\\n")'.format( + query_replace["{EDGEATTRS}"] = print_query + print_query = """SET tmp_e = (stringify(getvid(s)) + delimiter + stringify(getvid(t)) {} + "\\n", ""), + tmp_e_batch = tmp_e_batch UNION tmp_e""".format( "+ delimiter + " + print_attr if e_attr_names else "" ) - query_replace["{OTHEREDGEATTRS}"] = print_query + query_replace["{EDGEATTRSKAFKA}"] = print_query # Install query query_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), @@ -3889,13 +3904,7 @@ def _install_query(self, force: bool = False): "dataloaders", "edge_nei_loader.gsql", ) - sub_query_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "gsql", - "dataloaders", - "edge_nei_loader_sub.gsql", - ) - return install_query_files(self._graph, [sub_query_path, query_path], query_replace, force=force, distributed=[False, self.distributed_query]) + return install_query_file(self._graph, query_path, query_replace, force=force, distributed=self.distributed_query) def _start(self) -> None: # Create task and result queues @@ -3938,7 +3947,8 @@ def _start(self) -> None: add_self_loop = self.add_self_loop, delimiter = self.delimiter, is_hetero = self.is_hetero, - callback_fn = self.callback_fn + callback_fn = self.callback_fn, + seed_type = "edge" ), ) self._reader.start() diff --git a/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader.gsql b/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader.gsql index 3a3fc6ab..18788f5e 100644 --- a/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader.gsql +++ b/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader.gsql @@ -49,61 +49,123 @@ CREATE QUERY edge_nei_loader_{QUERYSUFFIX}( ssl_ca_location: Path to CA certificate for verifying the Kafka broker key. 
*/ SumAccum @tmp_id; + SumAccum @@kafka_error; + UINT producer; + SetAccum @seeds; start = {v_types}; # Filter seeds if needed - seeds = SELECT s + start = SELECT s FROM start:s -(seed_types:e)- v_types:t WHERE filter_by is NULL OR e.getAttr(filter_by, "BOOL") POST-ACCUM s.@tmp_id = getvid(s) POST-ACCUM t.@tmp_id = getvid(t); # Shuffle vertex ID if needed IF shuffle THEN - INT num_vertices = seeds.size(); + INT num_vertices = start.size(); res = SELECT s - FROM seeds:s + FROM start:s POST-ACCUM s.@tmp_id = floor(rand()*num_vertices) LIMIT 1; END; - # Generate batches # If using kafka to export IF kafka_address != "" THEN - SumAccum @@kafka_error; - # Initialize Kafka producer - UINT producer = init_kafka_producer( + producer = init_kafka_producer( kafka_address, kafka_max_size, security_protocol, sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, ssl_certificate_location, ssl_key_location, ssl_key_password, ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, sasl_kerberos_keytab, sasl_kerberos_principal); + END; + + FOREACH chunk IN RANGE[0, num_chunks-1] DO + MapAccum> @@v_batch; + MapAccum> @@e_batch; - FOREACH chunk IN RANGE[0, num_chunks-1] DO + # Collect neighborhood data for each vertex + seed1 = SELECT s + FROM start:s -(seed_types:e)- v_types:t + WHERE (filter_by IS NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk) + ; + seed2 = SELECT t + FROM start:s -(seed_types:e)- v_types:t + WHERE (filter_by IS NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk) + ; + seeds = seed1 UNION seed2; + seeds = SELECT s + FROM seeds:s + POST-ACCUM + s.@seeds += s, + {SEEDVERTEXATTRS}; + FOREACH hop IN RANGE[1, num_hops] DO + seeds = SELECT t + FROM seeds:s -(e_types:e)- v_types:t + SAMPLE num_neighbors EDGE WHEN s.outdegree() >= 1 + ACCUM + t.@seeds += s.@seeds, + FOREACH tmp_seed in s.@seeds DO + {EDGEATTRS} + END; + seeds = SELECT s + FROM seeds:s + POST-ACCUM + FOREACH tmp_seed in s.@seeds DO + {OTHERVERTEXATTRS} + END; + END; + # Clear all accums + all_v = {v_types}; + res = SELECT s + FROM all_v:s + POST-ACCUM s.@seeds.clear() + LIMIT 1; + + # Generate output for each edge + # If use kafka to export + IF kafka_address != "" THEN res = SELECT s - FROM seeds:s -(seed_types:e)- v_types:t + FROM seed1:s -(seed_types:e)- v_types:t WHERE (filter_by is NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk) ACCUM - STRING e_type = e.type, - LIST msg = edge_nei_loader_sub_{QUERYSUFFIX}(s, t, delimiter, num_hops, num_neighbors, e_types, v_types, e_type), - BOOL is_first=True, - FOREACH i in msg DO - IF is_first THEN - INT kafka_errcode = write_to_kafka(producer, kafka_topic, (getvid(s)+getvid(t))%kafka_topic_partitions, "vertex_batch_" + stringify(getvid(s))+e.type+stringify(getvid(t)), i), - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending vertex batch for " + stringify(getvid(s))+e.type+stringify(getvid(t)) + ": "+ stringify(kafka_errcode) + "\\n") - END, - is_first = False - ELSE - INT kafka_errcode = write_to_kafka(producer, kafka_topic, (getvid(s)+getvid(t))%kafka_topic_partitions, "edge_batch_" + stringify(getvid(s))+e.type+stringify(getvid(t)), i), - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending edge batch for " + stringify(getvid(s))+e.type+stringify(getvid(t)) + ": "+ stringify(kafka_errcode) + "\\n") - END - END + INT part_num = (getvid(s)+getvid(t))%kafka_topic_partitions, + STRING batch_id = 
stringify(getvid(s))+"_"+e.type+"_"+stringify(getvid(t)), + SET tmp_v_batch = @@v_batch.get(s) + @@v_batch.get(t), + INT kafka_errcode = write_to_kafka(producer, kafka_topic, part_num, "vertex_batch_"+batch_id, stringify(tmp_v_batch)), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending vertex batch for "+batch_id+": "+stringify(kafka_errcode) + "\n") + END, + SET tmp_e_batch = @@e_batch.get(s) + @@e_batch.get(t), + {EDGEATTRSKAFKA}, + kafka_errcode = write_to_kafka(producer, kafka_topic, part_num, "edge_batch_"+batch_id, stringify(tmp_e_batch)), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending edge batch for "+batch_id+ ": "+ stringify(kafka_errcode) + "\n") END LIMIT 1; + # Else return as http response + ELSE + MapAccum @@v_data; + MapAccum @@e_data; + res = SELECT s + FROM seed1:s -(seed_types:e)- v_types:t + WHERE (filter_by is NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk) + ACCUM + STRING batch_id = stringify(getvid(s))+"_"+e.type+"_"+stringify(getvid(t)), + SET tmp_v_batch = @@v_batch.get(s) + @@v_batch.get(t), + @@v_data += (batch_id -> stringify(tmp_v_batch)), + SET tmp_e_batch = @@e_batch.get(s) + @@e_batch.get(t), + {EDGEATTRSKAFKA}, + @@e_data += (batch_id -> stringify(tmp_e_batch)) + LIMIT 1; + + FOREACH (k,v) IN @@v_data DO + PRINT v as vertex_batch, @@e_data.get(k) as edge_batch, k AS seed; + END; END; - + END; + + IF kafka_address != "" THEN FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO INT kafka_errcode = write_to_kafka(producer, kafka_topic, i, "STOP", ""); IF kafka_errcode!=0 THEN @@ -116,32 +178,5 @@ CREATE QUERY edge_nei_loader_{QUERYSUFFIX}( @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); END; PRINT @@kafka_error as kafkaError; - # Else return as http response - ELSE - FOREACH chunk IN RANGE[0, num_chunks-1] DO - MapAccum @@v_batch; - MapAccum @@e_batch; - - res = SELECT s - FROM seeds:s -(seed_types:e)- v_types:t - WHERE (filter_by is NULL OR e.getAttr(filter_by, "BOOL")) and ((s.@tmp_id + t.@tmp_id) % num_chunks == chunk) - ACCUM - STRING e_type = e.type, - LIST msg = edge_nei_loader_sub_{QUERYSUFFIX}(s, t, delimiter, num_hops, num_neighbors, e_types, v_types, e_type), - BOOL is_first=True, - FOREACH i in msg DO - IF is_first THEN - @@v_batch += (stringify(getvid(s))+e.type+stringify(getvid(t)) -> i), - is_first = False - ELSE - @@e_batch += (stringify(getvid(s))+e.type+stringify(getvid(t)) -> i) - END - END - LIMIT 1; - - FOREACH (k,v) IN @@v_batch DO - PRINT v as vertex_batch, @@e_batch.get(k) as edge_batch; - END; - END; END; } \ No newline at end of file From 4d334a163adda81948e6837531daa523ceea2529 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Tue, 2 Jan 2024 13:11:40 -0800 Subject: [PATCH 34/36] fix(parse_edge_data): error when edge type has _ --- pyTigerGraph/gds/dataloaders.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index f2a95745..212d2d4b 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -1044,10 +1044,9 @@ def _parse_edge_data( # If seeds are given, create the is_seed column if seeds: seed_df = pd.DataFrame.from_records( - [i.split("_") for i in seeds], - columns=["source", "etype", "target"]) + [(i.split("_", 1)[0], i.rsplit("_", 1)[-1]) for i in seeds], + columns=["source", "target"]) seed_df["is_seed"] = True - del seed_df["etype"] data = data.merge(seed_df, on=["source", "target"], how="left") 
data.fillna({"is_seed": False}, inplace=True) else: @@ -1058,7 +1057,7 @@ def _parse_edge_data( # If seeds are given, create the is_seed column if seeds: seed_df = pd.DataFrame.from_records( - [i.split("_") for i in seeds], + [(i.split("_", 1)[0], i.split("_", 1)[1].rsplit("_", 1)[0], i.rsplit("_", 1)[-1]) for i in seeds], columns=["source", "etype", "target"]) seed_df["is_seed"] = True for etype in e_file_dict: From c704e027b31fe65053d3058d43330ed1a5563800 Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Wed, 3 Jan 2024 18:50:59 -0800 Subject: [PATCH 35/36] feat: allow kafka in distributed query --- pyTigerGraph/gds/dataloaders.py | 14 +++- pyTigerGraph/gds/gds.py | 12 +++- .../gds/gsql/dataloaders/edge_nei_loader.gsql | 70 +++++++++++++------ tests/test_gds_EdgeNeighborLoader.py | 32 +++++++++ 4 files changed, 102 insertions(+), 26 deletions(-) diff --git a/pyTigerGraph/gds/dataloaders.py b/pyTigerGraph/gds/dataloaders.py index 212d2d4b..01f5642a 100644 --- a/pyTigerGraph/gds/dataloaders.py +++ b/pyTigerGraph/gds/dataloaders.py @@ -90,7 +90,9 @@ def __init__( kafka_add_topic_per_epoch: bool = False, callback_fn: Callable = None, kafka_group_id: str = None, - kafka_topic: str = None + kafka_topic: str = None, + num_machines: int = 1, + num_segments: int = 20, ) -> None: """Base Class for data loaders. @@ -291,6 +293,8 @@ def __init__( ) # Initialize parameters for the query self._payload = {} + self._payload["num_machines"] = num_machines + self._payload["num_segments"] = num_segments if self.kafka_address_producer: self._payload["kafka_address"] = self.kafka_address_producer self._payload["kafka_topic_partitions"] = kafka_num_partitions @@ -3659,7 +3663,9 @@ def __init__( kafka_add_topic_per_epoch: bool = False, callback_fn: Callable = None, kafka_group_id: str = None, - kafka_topic: str = None + kafka_topic: str = None, + num_machines: int = 1, + num_segments: int = 20 ) -> None: """NO DOC""" @@ -3704,7 +3710,9 @@ def __init__( kafka_add_topic_per_epoch, callback_fn, kafka_group_id, - kafka_topic + kafka_topic, + num_machines, + num_segments ) # Resolve attributes is_hetero = any(map(lambda x: isinstance(x, dict), diff --git a/pyTigerGraph/gds/gds.py b/pyTigerGraph/gds/gds.py index 55d597c1..6a67a123 100644 --- a/pyTigerGraph/gds/gds.py +++ b/pyTigerGraph/gds/gds.py @@ -938,7 +938,9 @@ def edgeNeighborLoader( timeout: int = 300000, callback_fn: Callable = None, reinstall_query: bool = False, - distributed_query: bool = False + distributed_query: bool = False, + num_machines: int = 1, + num_segments: int = 20 ) -> EdgeNeighborLoader: """Returns an `EdgeNeighborLoader` instance. 
An `EdgeNeighborLoader` instance performs neighbor sampling from all edges in the graph in batches in the following manner: @@ -1098,7 +1100,9 @@ def edgeNeighborLoader( "delimiter": delimiter, "timeout": timeout, "callback_fn": callback_fn, - "distributed_query": distributed_query + "distributed_query": distributed_query, + "num_machines": num_machines, + "num_segments": num_segments } if self.kafkaConfig: params.update(self.kafkaConfig) @@ -1130,7 +1134,9 @@ def edgeNeighborLoader( "delimiter": delimiter, "timeout": timeout, "callback_fn": callback_fn, - "distributed_query": distributed_query + "distributed_query": distributed_query, + "num_machines": num_machines, + "num_segments": num_segments } if self.kafkaConfig: params.update(self.kafkaConfig) diff --git a/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader.gsql b/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader.gsql index 18788f5e..b6b2fb08 100644 --- a/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader.gsql +++ b/pyTigerGraph/gds/gsql/dataloaders/edge_nei_loader.gsql @@ -8,6 +8,8 @@ CREATE QUERY edge_nei_loader_{QUERYSUFFIX}( SET seed_types, STRING delimiter, INT num_chunks=2, + INT num_machines=1, + INT num_segments=20, STRING kafka_address="", STRING kafka_topic="", INT kafka_topic_partitions=1, @@ -50,8 +52,10 @@ CREATE QUERY edge_nei_loader_{QUERYSUFFIX}( */ SumAccum @tmp_id; SumAccum @@kafka_error; - UINT producer; SetAccum @seeds; + MapAccum> @@mid_to_vid; # This tmp accumulator maps machine ID to the smallest vertex ID on the machine. + MapAccum @@mid_to_producer; + SumAccum @kafka_producer_id; start = {v_types}; # Filter seeds if needed @@ -71,13 +75,32 @@ CREATE QUERY edge_nei_loader_{QUERYSUFFIX}( # If using kafka to export IF kafka_address != "" THEN - # Initialize Kafka producer - producer = init_kafka_producer( - kafka_address, kafka_max_size, security_protocol, - sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, - ssl_certificate_location, ssl_key_location, ssl_key_password, - ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, - sasl_kerberos_keytab, sasl_kerberos_principal); + # We generate a vertex set that contains exactly one vertex per machine. 
+ machine_set = + SELECT s + FROM start:s + ACCUM + INT mid = (getvid(s) >> num_segments & 31) % num_machines, + @@mid_to_vid += (mid -> getvid(s)) + HAVING @@mid_to_vid.get((getvid(s) >> num_segments & 31) % num_machines) == getvid(s); + @@mid_to_vid.clear(); + # Initialize Kafka producer per machine + res = SELECT s + FROM machine_set:s + ACCUM + INT mid = (getvid(s) >> num_segments & 31) % num_machines, + UINT producer = init_kafka_producer( + kafka_address, kafka_max_size, security_protocol, + sasl_mechanism, sasl_username, sasl_password, ssl_ca_location, + ssl_certificate_location, ssl_key_location, ssl_key_password, + ssl_endpoint_identification_algorithm, sasl_kerberos_service_name, + sasl_kerberos_keytab, sasl_kerberos_principal), + @@mid_to_producer += (mid -> producer); + res = SELECT s + FROM start:s + ACCUM + INT mid = (getvid(s) >> num_segments & 31) % num_machines, + s.@kafka_producer_id += @@mid_to_producer.get(mid); END; FOREACH chunk IN RANGE[0, num_chunks-1] DO @@ -132,13 +155,13 @@ CREATE QUERY edge_nei_loader_{QUERYSUFFIX}( INT part_num = (getvid(s)+getvid(t))%kafka_topic_partitions, STRING batch_id = stringify(getvid(s))+"_"+e.type+"_"+stringify(getvid(t)), SET tmp_v_batch = @@v_batch.get(s) + @@v_batch.get(t), - INT kafka_errcode = write_to_kafka(producer, kafka_topic, part_num, "vertex_batch_"+batch_id, stringify(tmp_v_batch)), + INT kafka_errcode = write_to_kafka(s.@kafka_producer_id, kafka_topic, part_num, "vertex_batch_"+batch_id, stringify(tmp_v_batch)), IF kafka_errcode!=0 THEN @@kafka_error += ("Error sending vertex batch for "+batch_id+": "+stringify(kafka_errcode) + "\n") END, SET tmp_e_batch = @@e_batch.get(s) + @@e_batch.get(t), {EDGEATTRSKAFKA}, - kafka_errcode = write_to_kafka(producer, kafka_topic, part_num, "edge_batch_"+batch_id, stringify(tmp_e_batch)), + kafka_errcode = write_to_kafka(s.@kafka_producer_id, kafka_topic, part_num, "edge_batch_"+batch_id, stringify(tmp_e_batch)), IF kafka_errcode!=0 THEN @@kafka_error += ("Error sending edge batch for "+batch_id+ ": "+ stringify(kafka_errcode) + "\n") END @@ -166,17 +189,24 @@ CREATE QUERY edge_nei_loader_{QUERYSUFFIX}( END; IF kafka_address != "" THEN - FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO - INT kafka_errcode = write_to_kafka(producer, kafka_topic, i, "STOP", ""); - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error sending STOP signal to topic partition " + stringify(i) + ": " + stringify(kafka_errcode) + "\n"); - END; - END; + res = SELECT s + FROM machine_set:s + WHERE (getvid(s) >> num_segments & 31) % num_machines == 0 + ACCUM + FOREACH i IN RANGE[0, kafka_topic_partitions-1] DO + INT kafka_errcode = write_to_kafka(s.@kafka_producer_id, kafka_topic, i, "STOP", ""), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error sending STOP signal to topic partition " + stringify(i) + ": " + stringify(kafka_errcode) + "\n") + END + END; - INT kafka_errcode = close_kafka_producer(producer, kafka_timeout); - IF kafka_errcode!=0 THEN - @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n"); - END; + res = SELECT s + FROM machine_set:s + ACCUM + INT kafka_errcode = close_kafka_producer(s.@kafka_producer_id, kafka_timeout), + IF kafka_errcode!=0 THEN + @@kafka_error += ("Error shutting down Kafka producer: " + stringify(kafka_errcode) + "\n") + END; PRINT @@kafka_error as kafkaError; END; } \ No newline at end of file diff --git a/tests/test_gds_EdgeNeighborLoader.py b/tests/test_gds_EdgeNeighborLoader.py index 7c503c54..102d07b8 100644 --- 
a/tests/test_gds_EdgeNeighborLoader.py +++ b/tests/test_gds_EdgeNeighborLoader.py @@ -59,6 +59,37 @@ def test_iterate_pyg(self): self.assertEqual(i, 1024) self.assertLessEqual(batch_sizes[-1], 1024) + def test_iterate_pyg_distributed(self): + loader = EdgeNeighborLoader( + graph=self.conn, + v_in_feats=["x"], + e_extra_feats=["is_train"], + batch_size=1024, + num_neighbors=10, + num_hops=2, + shuffle=True, + filter_by=None, + output_format="PyG", + kafka_address="kafka:9092", + distributed_query=True + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + self.assertIsInstance(data, pygData) + self.assertIn("x", data) + self.assertIn("is_seed", data) + self.assertIn("is_train", data) + self.assertGreater(data["x"].shape[0], 0) + self.assertGreater(data["edge_index"].shape[1], 0) + num_batches += 1 + batch_sizes.append(int(data["is_seed"].sum())) + self.assertEqual(num_batches, 11) + for i in batch_sizes[:-1]: + self.assertEqual(i, 1024) + self.assertLessEqual(batch_sizes[-1], 1024) + def test_sasl_ssl(self): loader = EdgeNeighborLoader( graph=self.conn, @@ -312,6 +343,7 @@ def test_iterate_pyg(self): suite = unittest.TestSuite() suite.addTest(TestGDSEdgeNeighborLoaderKafka("test_init")) suite.addTest(TestGDSEdgeNeighborLoaderKafka("test_iterate_pyg")) + suite.addTest(TestGDSEdgeNeighborLoaderKafka("test_iterate_pyg_distributed")) # suite.addTest(TestGDSEdgeNeighborLoaderKafka("test_sasl_ssl")) suite.addTest(TestGDSEdgeNeighborLoaderREST("test_init")) suite.addTest(TestGDSEdgeNeighborLoaderREST("test_iterate_pyg")) From 0ae8be8b0570f9c3d9125ac5f106a31f3fdaa42d Mon Sep 17 00:00:00 2001 From: Bill Shi Date: Mon, 12 Feb 2024 15:45:34 -0800 Subject: [PATCH 36/36] tests(EdgeNeighborLoader): test distributed query --- tests/test_gds_EdgeNeighborLoader.py | 48 ++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/test_gds_EdgeNeighborLoader.py b/tests/test_gds_EdgeNeighborLoader.py index 102d07b8..96c9e936 100644 --- a/tests/test_gds_EdgeNeighborLoader.py +++ b/tests/test_gds_EdgeNeighborLoader.py @@ -338,6 +338,53 @@ def test_iterate_pyg(self): self.assertEqual(i, 100) self.assertLessEqual(batch_sizes[-1], 100) + def test_iterate_pyg_distributed(self): + loader = EdgeNeighborLoader( + graph=self.conn, + v_in_feats={"v0": ["x", "y"], "v2": ["x"]}, + e_extra_feats={"v2v0":["is_train"], "v0v0":[], "v2v2":[]}, + e_seed_types=["v2v0"], + batch_size=100, + num_neighbors=5, + num_hops=2, + shuffle=True, + filter_by=None, + output_format="PyG", + kafka_address="kafka:9092", + distributed_query=True + ) + num_batches = 0 + batch_sizes = [] + for data in loader: + # print(num_batches, data) + self.assertIsInstance(data, pygHeteroData) + self.assertGreater(data["v0"]["x"].shape[0], 0) + self.assertGreater(data["v2"]["x"].shape[0], 0) + self.assertTrue( + data['v2', 'v2v0', 'v0']["edge_index"].shape[1] > 0 + and data['v2', 'v2v0', 'v0']["edge_index"].shape[1] <= 943 + ) + self.assertEqual( + data['v2', 'v2v0', 'v0']["edge_index"].shape[1], + data['v2', 'v2v0', 'v0']["is_train"].shape[0] + ) + if ('v0', 'v0v0', 'v0') in data.edge_types: + self.assertTrue( + data['v0', 'v0v0', 'v0']["edge_index"].shape[1] > 0 + and data['v0', 'v0v0', 'v0']["edge_index"].shape[1] <= 710 + ) + if ('v2', 'v2v2', 'v2') in data.edge_types: + self.assertTrue( + data['v2', 'v2v2', 'v2']["edge_index"].shape[1] > 0 + and data['v2', 'v2v2', 'v2']["edge_index"].shape[1] <= 966 + ) + num_batches += 1 + batch_sizes.append(int(data['v2', 'v2v0', 'v0']["is_seed"].sum())) + 
self.assertEqual(num_batches, 10) + for i in batch_sizes[:-1]: + self.assertEqual(i, 100) + self.assertLessEqual(batch_sizes[-1], 100) + if __name__ == "__main__": suite = unittest.TestSuite() @@ -352,5 +399,6 @@ def test_iterate_pyg(self): suite.addTest(TestGDSHeteroEdgeNeighborLoaderREST("test_iterate_pyg")) suite.addTest(TestGDSHeteroEdgeNeighborLoaderKafka("test_init")) suite.addTest(TestGDSHeteroEdgeNeighborLoaderKafka("test_iterate_pyg")) + suite.addTest(TestGDSHeteroEdgeNeighborLoaderKafka("test_iterate_pyg_distributed")) runner = unittest.TextTestRunner(verbosity=2, failfast=True) runner.run(suite)
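
Note on usage (a minimal sketch, not part of the patches above): the distributed Kafka path added in PATCH 35/36 can be exercised roughly as follows. The connection details, the Kafka broker address, and the cluster sizing values (num_machines, num_segments) are placeholders, and the graph is assumed to have an "x" vertex feature and an "is_train" edge feature, mirroring the test setup in tests/test_gds_EdgeNeighborLoader.py.

    from pyTigerGraph import TigerGraphConnection
    from pyTigerGraph.gds.dataloaders import EdgeNeighborLoader

    # Placeholder connection; point this at your own instance and graph.
    conn = TigerGraphConnection(host="http://127.0.0.1", graphname="Cora")

    loader = EdgeNeighborLoader(
        graph=conn,
        v_in_feats=["x"],            # vertex features to load (assumed to exist)
        e_extra_feats=["is_train"],  # extra edge attributes to load (assumed to exist)
        batch_size=1024,
        num_neighbors=10,
        num_hops=2,
        shuffle=True,
        filter_by=None,
        output_format="PyG",
        kafka_address="kafka:9092",  # placeholder broker address
        distributed_query=True,      # run the installed query in distributed mode
        num_machines=2,              # placeholder: number of machines in the cluster
        num_segments=20,             # placeholder: segment bits used to derive the machine ID
    )

    for data in loader:
        # Each batch is a PyG Data object; data["is_seed"] marks the seed edges.
        print(data)

With distributed_query=True, the rewritten edge_nei_loader.gsql initializes one Kafka producer per machine (via the machine-ID-to-producer map) instead of a single producer, so each machine streams its own vertex/edge batches to the configured topic.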