{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/miniconda3/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", "Picard is not available.\n" ] } ], "source": [ "from tqdm.notebook import tqdm\n", "from datasets.load import load_dataset\n", "from transformers.training_args_seq2seq import Seq2SeqTrainingArguments\n", "from seq2seq.utils.args import ModelArguments\n", "from seq2seq.utils.picard_model_wrapper import PicardArguments\n", "from seq2seq.utils.dataset import DataTrainingArguments, DataArguments\n", "from transformers.hf_argparser import HfArgumentParser" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "parser = HfArgumentParser(\n", " (PicardArguments, ModelArguments, DataArguments, DataTrainingArguments, Seq2SeqTrainingArguments)\n", " )" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "picard_args, model_args, data_args, data_training_args, training_args = parser.parse_json_file('./configs/worldcup_train.json')" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'/transformers_cache'" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_args.cache_dir" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DataArguments(dataset='spider', dataset_paths={'spider': './seq2seq/datasets/spider', 'cosql': './seq2seq/datasets/cosql', 'spider_realistic': './seq2seq/datasets/spider_realistic', 'spider_syn': './seq2seq/datasets/spider_syn', 'spider_dk': './seq2seq/datasets/spider_dk'}, metric_config='both', metric_paths={'spider': './seq2seq/metrics/spider', 'spider_realistic': './seq2seq/metrics/spider', 'cosql': './seq2seq/metrics/cosql', 'spider_syn': './seq2seq/metrics/spider', 'spider_dk': './seq2seq/metrics/spider'}, test_suite_db_dir=None, data_config_file=None, test_sections=None)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_args" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Reusing dataset spider (./transformers_cache/spider/spider/1.0.0/2ce75fa75bce00ee54968cffd084f961fdb357e6d67f97567e433f61279d35bc)\n", "100%|██████████| 2/2 [00:00<00:00, 6.84it/s]\n" ] } ], "source": [ "ds = load_dataset(\n", " path='./seq2seq/datasets/spider', cache_dir=\"./transformers_cache\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from __future__ import annotations\n", "from typing import Any, Iterable, List, Optional\n", "\n", "from sqlalchemy import MetaData, create_engine, inspect, select, text, func\n", "from sqlalchemy.engine import Engine\n", "from sqlalchemy.exc import ProgrammingError, SQLAlchemyError\n", "from sqlalchemy.schema import CreateTable\n", "from sqlalchemy.orm import sessionmaker\n", "import warnings\n", "from functools import lru_cache\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class SQLDatabase(object):\n", " \"\"\"SQLAlchemy wrapper around a database.\"\"\"\n", "\n", " def __init__(\n", " self,\n", " engine: Engine,\n", " schema: Optional[str] = None,\n", " metadata: Optional[MetaData] = None,\n", " ignore_tables: Optional[List[str]] = None,\n", " include_tables: Optional[List[str]] = None,\n", " sample_rows_in_table_info: int = 3,\n", " indexes_in_table_info: bool = False,\n", " custom_table_info: Optional[dict] = None,\n", " view_support: bool = False,\n", " max_string_length: int = 300,\n", " ):\n", " \"\"\"Create engine from database URI.\"\"\"\n", " self._engine = engine\n", " self._schema = schema\n", " if include_tables and ignore_tables:\n", " raise ValueError(\n", " \"Cannot specify both include_tables and ignore_tables\")\n", "\n", " self._inspector = inspect(self._engine)\n", " # including view support by adding the views as well as tables to the all\n", " # tables list if view_support is True\n", " self._all_tables = set(\n", " self._inspector.get_table_names(schema=schema)\n", " + (self._inspector.get_view_names(schema=schema) if view_support else [])\n", " )\n", " self._include_tables = set(include_tables) if include_tables else set()\n", " if self._include_tables:\n", " missing_tables = self._include_tables - self._all_tables\n", " if missing_tables:\n", " raise ValueError(\n", " f\"include_tables {missing_tables} not found in database\"\n", " )\n", " self._ignore_tables = set(ignore_tables) if ignore_tables else set()\n", " if self._ignore_tables:\n", " missing_tables = self._ignore_tables - self._all_tables\n", " if missing_tables:\n", " raise ValueError(\n", " f\"ignore_tables {missing_tables} not found in database\"\n", " )\n", " usable_tables = self.get_usable_table_names()\n", " self._usable_tables = set(usable_tables) if usable_tables else self._all_tables\n", "\n", " if not isinstance(sample_rows_in_table_info, int):\n", " raise TypeError(\"sample_rows_in_table_info must be an integer\")\n", "\n", " self._sample_rows_in_table_info = sample_rows_in_table_info\n", " self._indexes_in_table_info = indexes_in_table_info\n", "\n", " self._custom_table_info = custom_table_info\n", " if self._custom_table_info:\n", " if not isinstance(self._custom_table_info, dict):\n", " raise TypeError(\n", " \"table_info must be a dictionary with table names as keys and the \"\n", " \"desired table info as values\"\n", " )\n", " # only keep the tables that are also present in the database\n", " intersection = set(self._custom_table_info).intersection(self._all_tables)\n", " self._custom_table_info = dict(\n", " (table, self._custom_table_info[table])\n", " for table in self._custom_table_info\n", " if table in intersection\n", " )\n", "\n", " self._max_string_length = max_string_length\n", " \n", " self._metadata = metadata or MetaData()\n", " # including view support if view_support = true\n", " self._metadata.reflect(\n", " views=view_support,\n", " bind=self._engine,\n", " only=list(self._usable_tables),\n", " schema=self._schema,\n", " )\n", "\n", " @classmethod\n", " def from_uri(\n", " cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any\n", " ) -> SQLDatabase:\n", " \"\"\"Construct a SQLAlchemy engine from URI.\"\"\"\n", " _engine_args = engine_args or {}\n", " return cls(create_engine(database_uri, **_engine_args), **kwargs)\n", "\n", " @classmethod\n", " def from_databricks(\n", " cls,\n", " catalog: str,\n", " schema: str,\n", " host: Optional[str] = None,\n", " api_token: Optional[str] = None,\n", " warehouse_id: Optional[str] = None,\n", " cluster_id: Optional[str] = None,\n", " engine_args: Optional[dict] = None,\n", " **kwargs: Any,\n", " ) -> SQLDatabase:\n", " \"\"\"\n", " Class method to create an SQLDatabase instance from a Databricks connection.\n", " This method requires the 'databricks-sql-connector' package. If not installed,\n", " it can be added using `pip install databricks-sql-connector`.\n", "\n", " Args:\n", " catalog (str): The catalog name in the Databricks database.\n", " schema (str): The schema name in the catalog.\n", " host (Optional[str]): The Databricks workspace hostname, excluding\n", " 'https://' part. If not provided, it attempts to fetch from the\n", " environment variable 'DATABRICKS_HOST'. If still unavailable and if\n", " running in a Databricks notebook, it defaults to the current workspace\n", " hostname. Defaults to None.\n", " api_token (Optional[str]): The Databricks personal access token for\n", " accessing the Databricks SQL warehouse or the cluster. If not provided,\n", " it attempts to fetch from 'DATABRICKS_TOKEN'. If still unavailable\n", " and running in a Databricks notebook, a temporary token for the current\n", " user is generated. Defaults to None.\n", " warehouse_id (Optional[str]): The warehouse ID in the Databricks SQL. If\n", " provided, the method configures the connection to use this warehouse.\n", " Cannot be used with 'cluster_id'. Defaults to None.\n", " cluster_id (Optional[str]): The cluster ID in the Databricks Runtime. If\n", " provided, the method configures the connection to use this cluster.\n", " Cannot be used with 'warehouse_id'. If running in a Databricks notebook\n", " and both 'warehouse_id' and 'cluster_id' are None, it uses the ID of the\n", " cluster the notebook is attached to. Defaults to None.\n", " engine_args (Optional[dict]): The arguments to be used when connecting\n", " Databricks. Defaults to None.\n", " **kwargs (Any): Additional keyword arguments for the `from_uri` method.\n", "\n", " Returns:\n", " SQLDatabase: An instance of SQLDatabase configured with the provided\n", " Databricks connection details.\n", "\n", " Raises:\n", " ValueError: If 'databricks-sql-connector' is not found, or if both\n", " 'warehouse_id' and 'cluster_id' are provided, or if neither\n", " 'warehouse_id' nor 'cluster_id' are provided and it's not executing\n", " inside a Databricks notebook.\n", " \"\"\"\n", " try:\n", " from databricks import sql # noqa: F401\n", " except ImportError:\n", " raise ValueError(\n", " \"databricks-sql-connector package not found, please install with\"\n", " \" `pip install databricks-sql-connector`\"\n", " )\n", " context = None\n", " try:\n", " from dbruntime.databricks_repl_context import get_context\n", "\n", " context = get_context()\n", " except ImportError:\n", " pass\n", "\n", " default_host = context.browserHostName if context else None\n", " if host is None:\n", " host = tools.get_from_env(\"host\", \"DATABRICKS_HOST\", default_host)\n", "\n", " default_api_token = context.apiToken if context else None\n", " if api_token is None:\n", " api_token = tools.get_from_env(\n", " \"api_token\", \"DATABRICKS_TOKEN\", default_api_token\n", " )\n", "\n", " if warehouse_id is None and cluster_id is None:\n", " if context:\n", " cluster_id = context.clusterId\n", " else:\n", " raise ValueError(\n", " \"Need to provide either 'warehouse_id' or 'cluster_id'.\"\n", " )\n", "\n", " if warehouse_id and cluster_id:\n", " raise ValueError(\"Can't have both 'warehouse_id' or 'cluster_id'.\")\n", "\n", " if warehouse_id:\n", " http_path = f\"/sql/1.0/warehouses/{warehouse_id}\"\n", " else:\n", " http_path = f\"/sql/protocolv1/o/0/{cluster_id}\"\n", "\n", " uri = (\n", " f\"databricks://token:{api_token}@{host}?\"\n", " f\"http_path={http_path}&catalog={catalog}&schema={schema}\"\n", " )\n", " return cls.from_uri(database_uri=uri, engine_args=engine_args, **kwargs)\n", "\n", " @property\n", " def dialect(self) -> str:\n", " \"\"\"Return string representation of dialect to use.\"\"\"\n", " return self._engine.dialect.name\n", "\n", " def get_usable_table_names(self) -> Iterable[str]:\n", " \"\"\"Get names of tables available.\"\"\"\n", " if self._include_tables:\n", " return self._include_tables\n", " return self._all_tables - self._ignore_tables\n", " \n", " def get_table_names(self) -> Iterable[str]:\n", " \"\"\"Get names of tables available.\"\"\"\n", " warnings.warn(\n", " \"This method is deprecated - please use `get_usable_table_names`.\"\n", " )\n", " return self.get_usable_table_names()\n", " \n", " @property\n", " def table_info(self) -> str:\n", " \"\"\"Information about all tables in the database.\"\"\"\n", " return self.get_table_info()\n", "\n", " def get_table_info(self, table_names: Optional[List[str]] = None) -> str:\n", " \"\"\"Get information about specified tables.\n", "\n", " Follows best practices as specified in: Rajkumar et al, 2022\n", " (https://arxiv.org/abs/2204.00498)\n", "\n", " If `sample_rows_in_table_info`, the specified number of sample rows will be\n", " appended to each table description. This can increase performance as\n", " demonstrated in the paper.\n", " \"\"\"\n", " all_table_names = self.get_usable_table_names()\n", " if table_names is not None:\n", " missing_tables = set(table_names).difference(all_table_names)\n", " if missing_tables:\n", " raise ValueError(f\"table_names {missing_tables} not found in database\")\n", " all_table_names = table_names\n", "\n", " meta_tables = [\n", " tbl\n", " for tbl in self._metadata.sorted_tables\n", " if tbl.name in set(all_table_names)\n", " and not (self.dialect == \"sqlite\" and tbl.name.startswith(\"sqlite_\"))\n", " ]\n", "\n", " tables = []\n", " for table in meta_tables:\n", " if self._custom_table_info and table.name in self._custom_table_info:\n", " tables.append(self._custom_table_info[table.name])\n", " continue\n", "\n", " # add create table command\n", " create_table = str(CreateTable(table).compile(self._engine))\n", " table_info = f\"{create_table.rstrip()}\"\n", " has_extra_info = (\n", " self._indexes_in_table_info or self._sample_rows_in_table_info\n", " )\n", " if has_extra_info:\n", " table_info += \"\\n\\n/*\"\n", " if self._indexes_in_table_info:\n", " table_info += f\"\\n{self._get_table_indexes(table)}\\n\"\n", " if self._sample_rows_in_table_info:\n", " table_info += f\"\\n{self._get_sample_rows(table)}\\n\"\n", " if has_extra_info:\n", " table_info += \"*/\"\n", " tables.append(table_info)\n", " final_str = \"\\n\\n\".join(tables)\n", " return final_str\n", " \n", " def _get_table_indexes(self, table: Table) -> str:\n", " indexes = self._inspector.get_indexes(table.name)\n", " indexes_formatted = \"\\n\".join(map(_format_index, indexes))\n", " return f\"Table Indexes:\\n{indexes_formatted}\"\n", " \n", " def _get_sample_rows(self, table: Table) -> str:\n", "\n", " # build the select command\n", " command = select(table).limit(self._sample_rows_in_table_info)\n", "\n", " # save the command in string format\n", " select_star = (\n", " f\"SELECT * FROM '{table.name}' LIMIT \"\n", " f\"{self._sample_rows_in_table_info}\"\n", " )\n", "\n", " # save the columns in string format\n", " columns_str = \"\\t\".join([col.name for col in table.columns])\n", "\n", " try:\n", " # get the sample rows\n", " with self._engine.connect() as connection:\n", " try:\n", " sample_rows_result = connection.execute(command)\n", " # shorten values in the sample rows\n", " \n", " sample_rows = list(\n", " map(lambda ls: [str(i)[:100]\n", " for i in ls], sample_rows_result)\n", " )\n", " except TypeError as e:\n", " # print(\"***Back to literal querying...***\")\n", " sample_rows_result = connection.exec_driver_sql(select_star)\n", " # shorten values in the sample rows\n", " sample_rows = list(\n", " map(lambda ls: [str(i)[:100]\n", " for i in ls], sample_rows_result)\n", " )\n", "\n", " # save the sample rows in string format\n", " sample_rows_str = \"\\n\".join([\"\\t\".join(row) for row in sample_rows])\n", " # in some dialects when there are no rows in the table a\n", " # 'ProgrammingError' is returned\n", " except ProgrammingError:\n", " sample_rows_str = \"\"\n", "\n", " return (\n", " f\"{self._sample_rows_in_table_info} rows from {table.name} table:\\n\"\n", " f\"{columns_str}\\n\"\n", " f\"{sample_rows_str}\"\n", " )\n", "\n", " \n", " def get_table_info_dict(self, table_names: Optional[List[str]] = None, with_col_details = True, do_sampling = True) -> dict:\n", " all_table_names = self.get_usable_table_names()\n", " if table_names is not None:\n", " missing_tables = set(table_names).difference(all_table_names)\n", " if missing_tables:\n", " raise ValueError(f\"table_names {missing_tables} not found in database\")\n", " all_table_names = table_names\n", "\n", " meta_tables = [\n", " tbl\n", " for tbl in self._metadata.sorted_tables\n", " if tbl.name in set(all_table_names)\n", " and not (self.dialect == \"sqlite\" and tbl.name.startswith(\"sqlite_\"))\n", " ]\n", "\n", " tables = []\n", " \"\"\"\n", " all_table_names = self.get_table_names()\n", " if table_names is not None:\n", " missing_tables = set(table_names).difference(all_table_names)\n", " if missing_tables:\n", " raise ValueError(\n", " f\"table_names {missing_tables} not found in database\")\n", " all_table_names = table_names\n", "\n", " meta_tables = [\n", " tbl\n", " for tbl in self._metadata.sorted_tables\n", " if tbl.name in set(all_table_names)\n", " and not (self.dialect == \"sqlite\" and tbl.name.startswith(\"sqlite_\"))\n", " ]\n", "\n", " tables = []\n", " \"\"\"\n", "\n", " for table in meta_tables:\n", " if self._custom_table_info and table.name in self._custom_table_info:\n", " tables.append(self._custom_table_info[table.name])\n", " else:\n", " tables.append(table)\n", " tables_dict = {}\n", " for table in tables:\n", " cols = []\n", " col_details = []\n", " pk = []\n", " fks = []\n", " sample_rows = []\n", " num_rows = 0\n", " if do_sampling:\n", " sample_rows = self.get_tbl_samples_dict(table)\n", " num_rows = self.get_rows_of_a_table(table)\n", " for col in table.columns:\n", " cols.append([col.name, str(col.type).split('.')[-1]])\n", " if col.primary_key:\n", " pk.append(col.name)\n", " if len(col.foreign_keys) > 0:\n", " for fk in list(col.foreign_keys):\n", " fks.append([f'{table.name}.{col.name}', '.'.join(fk.target_fullname.split('.')[-2:])])\n", " \n", " if with_col_details and num_rows > 0:\n", " distinct_values = self.count_distinct_values_of_a_col(table, col)\n", " cardinality = len(distinct_values) / num_rows\n", " # here we use 3 simple conditions to filterout the categorical values:\n", " # 1. cardinality < 0.3\n", " # 2. total len(distinct_values) < 20\n", " # 3. '_id' not in name or name is not equal to 'id'\n", " if cardinality < 0.5 and len(distinct_values) < 20 and ('_id' not in col.name.lower() or col.name.lower() == 'id'): # maybe a categorical value\n", " col_details.append({'is_categorical': True, 'cardinality': cardinality, 'distinct_values': distinct_values})\n", " else:\n", " col_details.append({'is_categorical': False, 'cardinality': cardinality, 'distinct_values': distinct_values[:20]})\n", " \n", " tables_dict[table.name] = {\n", " 'COL': cols,\n", " 'PK': pk,\n", " 'FK': fks\n", " }\n", " if do_sampling:\n", " tables_dict[table.name]['sample_rows'] = sample_rows\n", " if with_col_details:\n", " tables_dict[table.name]['COL_DETAILS'] = col_details \n", " return tables_dict\n", " \n", " def get_tbl_samples_dict(self, table):\n", " sample_rows_dict = {}\n", " if self._sample_rows_in_table_info:\n", " # build the select command\n", " command = select(table).limit(self._sample_rows_in_table_info)\n", "\n", " # save the command in string format\n", " select_star = (\n", " f\"SELECT * FROM '{table.name}' LIMIT \"\n", " f\"{self._sample_rows_in_table_info}\"\n", " )\n", "\n", " # save the columns\n", " columns = [col.name for col in table.columns]\n", "\n", " # get the sample rows\n", " try:\n", " with self._engine.connect() as connection:\n", " try:\n", " sample_rows = connection.execute(command)\n", " # shorten values in the sample rows\n", " \n", " sample_rows = list(\n", " map(lambda ls: [str(i)[:100]\n", " for i in ls], sample_rows)\n", " )\n", " except TypeError as e:\n", " # print(\"***Back to literal querying...***\")\n", " sample_rows = connection.exec_driver_sql(select_star)\n", " # shorten values in the sample rows\n", " sample_rows = list(\n", " map(lambda ls: [str(i)[:100]\n", " for i in ls], sample_rows)\n", " )\n", " sample_rows_T = list(map(list, zip(*sample_rows)))\n", " for col, rows in zip(columns, sample_rows_T):\n", " sample_rows_dict[col] = rows\n", " except ProgrammingError:\n", " print('Warning: sampling error')\n", " sample_rows_dict = {}\n", " return sample_rows_dict\n", "\n", " def get_rows_of_a_table(self, table):\n", " command = select(func.count()).select_from(table)\n", " try:\n", " with self._engine.connect() as connection:\n", " num_rows = connection.execute(command)\n", " # print(table.name)\n", " return num_rows.scalar()\n", " except ProgrammingError:\n", " warnings.warn('Fetching categorical values error')\n", " return None\n", " \n", " def count_distinct_values_of_a_col(self, table, column, num_limit=100):\n", " command = select(func.count(column), column).group_by(column).order_by(func.count(column).desc()).limit(num_limit)\n", " try:\n", " with self._engine.connect() as connection:\n", " try:\n", " sample_rows = connection.execute(command).fetchall()\n", " # print(table.name, column.name)\n", " return [list(r) for r in sample_rows]\n", " except ValueError as e:\n", " print(f\"ValueError: {e.__traceback__}\")\n", " # backdraw to use exec_driver_sql method\n", " select_str = (\n", " f\"SELECT COUNT(*) FROM '{table.name}' GROUP BY {column.name} ORDER BY COUNT(*) LIMIT {num_limit}\")\n", " sample_rows = connection.exec_driver_sql(select_str).fetchall()\n", " return [list(r) for r in sample_rows]\n", " except ProgrammingError:\n", " print('Warning: categorical error')\n", " return []\n", " \n", " def run(self, command: str, fetch: str = \"all\", fmt = \"str\") -> str:\n", " \"\"\"Execute a SQL command and return a string representing the results.\n", "\n", " If the statement returns rows, a string of the results is returned.\n", " If the statement returns no rows, an empty string is returned.\n", " \"\"\"\n", " with self._engine.begin() as connection:\n", " if self._schema is not None:\n", " if self.dialect == \"snowflake\":\n", " connection.exec_driver_sql(\n", " f\"ALTER SESSION SET search_path='{self._schema}'\"\n", " )\n", " elif self.dialect == \"bigquery\":\n", " connection.exec_driver_sql(f\"SET @@dataset_id='{self._schema}'\")\n", " else:\n", " connection.exec_driver_sql(f\"SET search_path TO {self._schema}\")\n", " cursor = connection.execute(text(command))\n", " if cursor.returns_rows:\n", " if fetch == \"all\":\n", " result = cursor.fetchall()\n", " elif fetch == \"one\":\n", " result = cursor.fetchone()[0]\n", " else:\n", " raise ValueError(\n", " \"Fetch parameter must be either 'one' or 'all'\")\n", " if fmt == \"str\":\n", " return str(result)\n", " elif fmt == \"list\":\n", " return list(result)\n", " return \"\"\n", "\n", " def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:\n", " \"\"\"Get information about specified tables.\n", "\n", " Follows best practices as specified in: Rajkumar et al, 2022\n", " (https://arxiv.org/abs/2204.00498)\n", "\n", " If `sample_rows_in_table_info`, the specified number of sample rows will be\n", " appended to each table description. This can increase performance as\n", " demonstrated in the paper.\n", " \"\"\"\n", " try:\n", " return self.get_table_info(table_names)\n", " except ValueError as e:\n", " \"\"\"Format the error message\"\"\"\n", " return f\"Error: {e}\"\n", "\n", " def run_no_throw(self, command: str, fetch: str = \"all\") -> str:\n", " \"\"\"Execute a SQL command and return a string representing the results.\n", "\n", " If the statement returns rows, a string of the results is returned.\n", " If the statement returns no rows, an empty string is returned.\n", "\n", " If the statement throws an error, the error message is returned.\n", " \"\"\"\n", " try:\n", " return self.run(command, fetch)\n", " except SQLAlchemyError as e:\n", " \"\"\"Format the error message\"\"\"\n", " return f\"Error: {e}\"\n", " \n", " def dict2str(self, d):\n", " text = []\n", " for t,v in d.items():\n", " _tbl = f'{t}:'\n", " cols = []\n", " pks = ['PK:']\n", " fks = ['FK:']\n", " for col in v['COL']:\n", " cols.append(f'{col[0]}:{self.aliastype(col[1])}')\n", " for pk in v['PK']:\n", " pks.append(pk)\n", " for fk in v['FK']:\n", " fks.append('='.join(list(fk)))\n", " \n", " tbl = '\\n'.join([_tbl, ', '.join(cols), ' '.join(pks), ' '.join(fks)])\n", " text.append(tbl)\n", " return '\\n'.join(text)\n", " \n", " def aliastype(self, t):\n", " _t = t[:3].lower()\n", " if _t in ['int', 'tin', 'sma', 'med', 'big','uns', 'rea', 'dou', 'num', 'dec', 'tim']:\n", " res = 'N' # numerical value\n", " elif _t in ['tex', 'var', 'cha', 'nch', 'nat', 'nva', 'clo']:\n", " res = 'T' # text value\n", " elif _t in ['boo']:\n", " res = 'B'\n", " elif _t in ['dat']:\n", " res = 'D'\n", " else:\n", " raise ValueError('Unsupported data type')\n", " return res\n", " \n", " @lru_cache(maxsize=1000)\n", " def cached_query_results(self, query, limit_num=100):\n", " session = self.Session()\n", " try:\n", " result = session.execute(query).fetchmany(limit_num)\n", " return result\n", " except Exception as e:\n", " print(f\"Error executing query: {query}. Error: {str(e)}\")\n", " raise\n", " finally:\n", " session.close()\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "db.get_table_info_dict(do_sampling=False, with_col_details=False)" ] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.16" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }