{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Picard is not available.\n"
     ]
    }
   ],
   "source": [
    "from tqdm.notebook import tqdm\n",
    "from datasets.load import load_dataset\n",
    "from transformers.training_args_seq2seq import Seq2SeqTrainingArguments\n",
    "from seq2seq.utils.args import ModelArguments\n",
    "from seq2seq.utils.picard_model_wrapper import PicardArguments\n",
    "from seq2seq.utils.dataset import DataTrainingArguments, DataArguments\n",
    "from transformers.hf_argparser import HfArgumentParser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "parser = HfArgumentParser(\n",
    "        (PicardArguments, ModelArguments, DataArguments, DataTrainingArguments, Seq2SeqTrainingArguments)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "picard_args, model_args, data_args, data_training_args, training_args = parser.parse_json_file('./configs/worldcup_train.json')"
   ]
  },
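  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`parse_json_file` maps the top-level keys of the JSON config onto the five dataclasses, in the order they were passed to `HfArgumentParser`. The sketch below shows the equivalent call from an in-memory dict via `parse_dict`; the keys and values are illustrative placeholders, not the contents of `worldcup_train.json`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only: parse_dict is the in-memory counterpart of parse_json_file.\n",
    "# The keys below are assumptions; any field left out keeps its dataclass default.\n",
    "# example_args = {\n",
    "#     'model_name_or_path': 't5-small',  # -> ModelArguments\n",
    "#     'dataset': 'spider',               # -> DataArguments\n",
    "#     'output_dir': '/tmp/seq2seq_out',  # -> Seq2SeqTrainingArguments\n",
    "# }\n",
    "# picard_args, model_args, data_args, data_training_args, training_args = parser.parse_dict(example_args)"
   ]
  },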
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/transformers_cache'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_args.cache_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DataArguments(dataset='spider', dataset_paths={'spider': './seq2seq/datasets/spider', 'cosql': './seq2seq/datasets/cosql', 'spider_realistic': './seq2seq/datasets/spider_realistic', 'spider_syn': './seq2seq/datasets/spider_syn', 'spider_dk': './seq2seq/datasets/spider_dk'}, metric_config='both', metric_paths={'spider': './seq2seq/metrics/spider', 'spider_realistic': './seq2seq/metrics/spider', 'cosql': './seq2seq/metrics/cosql', 'spider_syn': './seq2seq/metrics/spider', 'spider_dk': './seq2seq/metrics/spider'}, test_suite_db_dir=None, data_config_file=None, test_sections=None)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Reusing dataset spider (./transformers_cache/spider/spider/1.0.0/2ce75fa75bce00ee54968cffd084f961fdb357e6d67f97567e433f61279d35bc)\n",
      "100%|██████████| 2/2 [00:00<00:00,  6.84it/s]\n"
     ]
    }
   ],
   "source": [
    "ds = load_dataset(\n",
    "        path='./seq2seq/datasets/spider', cache_dir=\"./transformers_cache\")"
   ]
  },
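  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check on the loaded splits. This assumes the local Spider loader exposes the standard Spider fields (`db_id`, `question`, `query`); adjust the keys if the loader differs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# split sizes plus one example; field names assume the standard Spider schema\n",
    "print({split: len(ds[split]) for split in ds})\n",
    "example = ds['train'][0]\n",
    "print(example['db_id'], '|', example['question'], '->', example['query'])"
   ]
  },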
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "from typing import Any, Iterable, List, Optional\n",
    "\n",
    "from sqlalchemy import MetaData, create_engine, inspect, select, text, func\n",
    "from sqlalchemy.engine import Engine\n",
    "from sqlalchemy.exc import ProgrammingError, SQLAlchemyError\n",
    "from sqlalchemy.schema import CreateTable\n",
    "from sqlalchemy.orm import sessionmaker\n",
    "import warnings\n",
    "from functools import lru_cache\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SQLDatabase(object):\n",
    "    \"\"\"SQLAlchemy wrapper around a database.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        engine: Engine,\n",
    "        schema: Optional[str] = None,\n",
    "        metadata: Optional[MetaData] = None,\n",
    "        ignore_tables: Optional[List[str]] = None,\n",
    "        include_tables: Optional[List[str]] = None,\n",
    "        sample_rows_in_table_info: int = 3,\n",
    "        indexes_in_table_info: bool = False,\n",
    "        custom_table_info: Optional[dict] = None,\n",
    "        view_support: bool = False,\n",
    "        max_string_length: int = 300,\n",
    "    ):\n",
    "        \"\"\"Create engine from database URI.\"\"\"\n",
    "        self._engine = engine\n",
    "        self._schema = schema\n",
    "        if include_tables and ignore_tables:\n",
    "            raise ValueError(\n",
    "                \"Cannot specify both include_tables and ignore_tables\")\n",
    "\n",
    "        self._inspector = inspect(self._engine)\n",
    "        # including view support by adding the views as well as tables to the all\n",
    "        # tables list if view_support is True\n",
    "        self._all_tables = set(\n",
    "            self._inspector.get_table_names(schema=schema)\n",
    "            + (self._inspector.get_view_names(schema=schema) if view_support else [])\n",
    "        )\n",
    "        self._include_tables = set(include_tables) if include_tables else set()\n",
    "        if self._include_tables:\n",
    "            missing_tables = self._include_tables - self._all_tables\n",
    "            if missing_tables:\n",
    "                raise ValueError(\n",
    "                    f\"include_tables {missing_tables} not found in database\"\n",
    "                )\n",
    "        self._ignore_tables = set(ignore_tables) if ignore_tables else set()\n",
    "        if self._ignore_tables:\n",
    "            missing_tables = self._ignore_tables - self._all_tables\n",
    "            if missing_tables:\n",
    "                raise ValueError(\n",
    "                    f\"ignore_tables {missing_tables} not found in database\"\n",
    "                )\n",
    "        usable_tables = self.get_usable_table_names()\n",
    "        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables\n",
    "\n",
    "        if not isinstance(sample_rows_in_table_info, int):\n",
    "            raise TypeError(\"sample_rows_in_table_info must be an integer\")\n",
    "\n",
    "        self._sample_rows_in_table_info = sample_rows_in_table_info\n",
    "        self._indexes_in_table_info = indexes_in_table_info\n",
    "\n",
    "        self._custom_table_info = custom_table_info\n",
    "        if self._custom_table_info:\n",
    "            if not isinstance(self._custom_table_info, dict):\n",
    "                raise TypeError(\n",
    "                    \"table_info must be a dictionary with table names as keys and the \"\n",
    "                    \"desired table info as values\"\n",
    "                )\n",
    "            # only keep the tables that are also present in the database\n",
    "            intersection = set(self._custom_table_info).intersection(self._all_tables)\n",
    "            self._custom_table_info = dict(\n",
    "                (table, self._custom_table_info[table])\n",
    "                for table in self._custom_table_info\n",
    "                if table in intersection\n",
    "            )\n",
    "\n",
    "        self._max_string_length = max_string_length\n",
    "        \n",
    "        self._metadata = metadata or MetaData()\n",
    "        # including view support if view_support = true\n",
    "        self._metadata.reflect(\n",
    "            views=view_support,\n",
    "            bind=self._engine,\n",
    "            only=list(self._usable_tables),\n",
    "            schema=self._schema,\n",
    "            )\n",
    "\n",
    "    @classmethod\n",
    "    def from_uri(\n",
    "        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any\n",
    "    ) -> SQLDatabase:\n",
    "        \"\"\"Construct a SQLAlchemy engine from URI.\"\"\"\n",
    "        _engine_args = engine_args or {}\n",
    "        return cls(create_engine(database_uri, **_engine_args), **kwargs)\n",
    "\n",
    "    @classmethod\n",
    "    def from_databricks(\n",
    "        cls,\n",
    "        catalog: str,\n",
    "        schema: str,\n",
    "        host: Optional[str] = None,\n",
    "        api_token: Optional[str] = None,\n",
    "        warehouse_id: Optional[str] = None,\n",
    "        cluster_id: Optional[str] = None,\n",
    "        engine_args: Optional[dict] = None,\n",
    "        **kwargs: Any,\n",
    "    ) -> SQLDatabase:\n",
    "        \"\"\"\n",
    "        Class method to create an SQLDatabase instance from a Databricks connection.\n",
    "        This method requires the 'databricks-sql-connector' package. If not installed,\n",
    "        it can be added using `pip install databricks-sql-connector`.\n",
    "\n",
    "        Args:\n",
    "            catalog (str): The catalog name in the Databricks database.\n",
    "            schema (str): The schema name in the catalog.\n",
    "            host (Optional[str]): The Databricks workspace hostname, excluding\n",
    "                'https://' part. If not provided, it attempts to fetch from the\n",
    "                environment variable 'DATABRICKS_HOST'. If still unavailable and if\n",
    "                running in a Databricks notebook, it defaults to the current workspace\n",
    "                hostname. Defaults to None.\n",
    "            api_token (Optional[str]): The Databricks personal access token for\n",
    "                accessing the Databricks SQL warehouse or the cluster. If not provided,\n",
    "                it attempts to fetch from 'DATABRICKS_TOKEN'. If still unavailable\n",
    "                and running in a Databricks notebook, a temporary token for the current\n",
    "                user is generated. Defaults to None.\n",
    "            warehouse_id (Optional[str]): The warehouse ID in the Databricks SQL. If\n",
    "                provided, the method configures the connection to use this warehouse.\n",
    "                Cannot be used with 'cluster_id'. Defaults to None.\n",
    "            cluster_id (Optional[str]): The cluster ID in the Databricks Runtime. If\n",
    "                provided, the method configures the connection to use this cluster.\n",
    "                Cannot be used with 'warehouse_id'. If running in a Databricks notebook\n",
    "                and both 'warehouse_id' and 'cluster_id' are None, it uses the ID of the\n",
    "                cluster the notebook is attached to. Defaults to None.\n",
    "            engine_args (Optional[dict]): The arguments to be used when connecting\n",
    "                Databricks. Defaults to None.\n",
    "            **kwargs (Any): Additional keyword arguments for the `from_uri` method.\n",
    "\n",
    "        Returns:\n",
    "            SQLDatabase: An instance of SQLDatabase configured with the provided\n",
    "                Databricks connection details.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If 'databricks-sql-connector' is not found, or if both\n",
    "                'warehouse_id' and 'cluster_id' are provided, or if neither\n",
    "                'warehouse_id' nor 'cluster_id' are provided and it's not executing\n",
    "                inside a Databricks notebook.\n",
    "        \"\"\"\n",
    "        try:\n",
    "            from databricks import sql  # noqa: F401\n",
    "        except ImportError:\n",
    "            raise ValueError(\n",
    "                \"databricks-sql-connector package not found, please install with\"\n",
    "                \" `pip install databricks-sql-connector`\"\n",
    "            )\n",
    "        context = None\n",
    "        try:\n",
    "            from dbruntime.databricks_repl_context import get_context\n",
    "\n",
    "            context = get_context()\n",
    "        except ImportError:\n",
    "            pass\n",
    "\n",
    "        default_host = context.browserHostName if context else None\n",
    "        if host is None:\n",
    "            host = tools.get_from_env(\"host\", \"DATABRICKS_HOST\", default_host)\n",
    "\n",
    "        default_api_token = context.apiToken if context else None\n",
    "        if api_token is None:\n",
    "            api_token = tools.get_from_env(\n",
    "                \"api_token\", \"DATABRICKS_TOKEN\", default_api_token\n",
    "            )\n",
    "\n",
    "        if warehouse_id is None and cluster_id is None:\n",
    "            if context:\n",
    "                cluster_id = context.clusterId\n",
    "            else:\n",
    "                raise ValueError(\n",
    "                    \"Need to provide either 'warehouse_id' or 'cluster_id'.\"\n",
    "                )\n",
    "\n",
    "        if warehouse_id and cluster_id:\n",
    "            raise ValueError(\"Can't have both 'warehouse_id' or 'cluster_id'.\")\n",
    "\n",
    "        if warehouse_id:\n",
    "            http_path = f\"/sql/1.0/warehouses/{warehouse_id}\"\n",
    "        else:\n",
    "            http_path = f\"/sql/protocolv1/o/0/{cluster_id}\"\n",
    "\n",
    "        uri = (\n",
    "            f\"databricks://token:{api_token}@{host}?\"\n",
    "            f\"http_path={http_path}&catalog={catalog}&schema={schema}\"\n",
    "        )\n",
    "        return cls.from_uri(database_uri=uri, engine_args=engine_args, **kwargs)\n",
    "\n",
    "    @property\n",
    "    def dialect(self) -> str:\n",
    "        \"\"\"Return string representation of dialect to use.\"\"\"\n",
    "        return self._engine.dialect.name\n",
    "\n",
    "    def get_usable_table_names(self) -> Iterable[str]:\n",
    "        \"\"\"Get names of tables available.\"\"\"\n",
    "        if self._include_tables:\n",
    "            return self._include_tables\n",
    "        return self._all_tables - self._ignore_tables\n",
    "    \n",
    "    def get_table_names(self) -> Iterable[str]:\n",
    "        \"\"\"Get names of tables available.\"\"\"\n",
    "        warnings.warn(\n",
    "            \"This method is deprecated - please use `get_usable_table_names`.\"\n",
    "        )\n",
    "        return self.get_usable_table_names()\n",
    "    \n",
    "    @property\n",
    "    def table_info(self) -> str:\n",
    "        \"\"\"Information about all tables in the database.\"\"\"\n",
    "        return self.get_table_info()\n",
    "\n",
    "    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:\n",
    "        \"\"\"Get information about specified tables.\n",
    "\n",
    "        Follows best practices as specified in: Rajkumar et al, 2022\n",
    "        (https://arxiv.org/abs/2204.00498)\n",
    "\n",
    "        If `sample_rows_in_table_info`, the specified number of sample rows will be\n",
    "        appended to each table description. This can increase performance as\n",
    "        demonstrated in the paper.\n",
    "        \"\"\"\n",
    "        all_table_names = self.get_usable_table_names()\n",
    "        if table_names is not None:\n",
    "            missing_tables = set(table_names).difference(all_table_names)\n",
    "            if missing_tables:\n",
    "                raise ValueError(f\"table_names {missing_tables} not found in database\")\n",
    "            all_table_names = table_names\n",
    "\n",
    "        meta_tables = [\n",
    "            tbl\n",
    "            for tbl in self._metadata.sorted_tables\n",
    "            if tbl.name in set(all_table_names)\n",
    "            and not (self.dialect == \"sqlite\" and tbl.name.startswith(\"sqlite_\"))\n",
    "        ]\n",
    "\n",
    "        tables = []\n",
    "        for table in meta_tables:\n",
    "            if self._custom_table_info and table.name in self._custom_table_info:\n",
    "                tables.append(self._custom_table_info[table.name])\n",
    "                continue\n",
    "\n",
    "            # add create table command\n",
    "            create_table = str(CreateTable(table).compile(self._engine))\n",
    "            table_info = f\"{create_table.rstrip()}\"\n",
    "            has_extra_info = (\n",
    "                self._indexes_in_table_info or self._sample_rows_in_table_info\n",
    "            )\n",
    "            if has_extra_info:\n",
    "                table_info += \"\\n\\n/*\"\n",
    "            if self._indexes_in_table_info:\n",
    "                table_info += f\"\\n{self._get_table_indexes(table)}\\n\"\n",
    "            if self._sample_rows_in_table_info:\n",
    "                table_info += f\"\\n{self._get_sample_rows(table)}\\n\"\n",
    "            if has_extra_info:\n",
    "                table_info += \"*/\"\n",
    "            tables.append(table_info)\n",
    "        final_str = \"\\n\\n\".join(tables)\n",
    "        return final_str\n",
    "    \n",
    "    def _get_table_indexes(self, table: Table) -> str:\n",
    "        indexes = self._inspector.get_indexes(table.name)\n",
    "        indexes_formatted = \"\\n\".join(map(_format_index, indexes))\n",
    "        return f\"Table Indexes:\\n{indexes_formatted}\"\n",
    "    \n",
    "    def _get_sample_rows(self, table: Table) -> str:\n",
    "\n",
    "        # build the select command\n",
    "        command = select(table).limit(self._sample_rows_in_table_info)\n",
    "\n",
    "        # save the command in string format\n",
    "        select_star = (\n",
    "            f\"SELECT * FROM '{table.name}' LIMIT \"\n",
    "            f\"{self._sample_rows_in_table_info}\"\n",
    "        )\n",
    "\n",
    "        # save the columns in string format\n",
    "        columns_str = \"\\t\".join([col.name for col in table.columns])\n",
    "\n",
    "        try:\n",
    "            # get the sample rows\n",
    "            with self._engine.connect() as connection:\n",
    "                try:\n",
    "                    sample_rows_result = connection.execute(command)\n",
    "                    # shorten values in the sample rows\n",
    "                \n",
    "                    sample_rows = list(\n",
    "                        map(lambda ls: [str(i)[:100]\n",
    "                            for i in ls], sample_rows_result)\n",
    "                    )\n",
    "                except TypeError as e:\n",
    "                    # print(\"***Back to literal querying...***\")\n",
    "                    sample_rows_result = connection.exec_driver_sql(select_star)\n",
    "                    # shorten values in the sample rows\n",
    "                    sample_rows = list(\n",
    "                        map(lambda ls: [str(i)[:100]\n",
    "                            for i in ls], sample_rows_result)\n",
    "                    )\n",
    "\n",
    "            # save the sample rows in string format\n",
    "            sample_rows_str = \"\\n\".join([\"\\t\".join(row) for row in sample_rows])\n",
    "        # in some dialects when there are no rows in the table a\n",
    "        # 'ProgrammingError' is returned\n",
    "        except ProgrammingError:\n",
    "            sample_rows_str = \"\"\n",
    "\n",
    "        return (\n",
    "            f\"{self._sample_rows_in_table_info} rows from {table.name} table:\\n\"\n",
    "            f\"{columns_str}\\n\"\n",
    "            f\"{sample_rows_str}\"\n",
    "        )\n",
    "\n",
    "    \n",
    "    def get_table_info_dict(self, table_names: Optional[List[str]] = None, with_col_details = True, do_sampling = True) -> dict:\n",
    "        all_table_names = self.get_usable_table_names()\n",
    "        if table_names is not None:\n",
    "            missing_tables = set(table_names).difference(all_table_names)\n",
    "            if missing_tables:\n",
    "                raise ValueError(f\"table_names {missing_tables} not found in database\")\n",
    "            all_table_names = table_names\n",
    "\n",
    "        meta_tables = [\n",
    "            tbl\n",
    "            for tbl in self._metadata.sorted_tables\n",
    "            if tbl.name in set(all_table_names)\n",
    "            and not (self.dialect == \"sqlite\" and tbl.name.startswith(\"sqlite_\"))\n",
    "        ]\n",
    "\n",
    "        tables = []\n",
    "        \"\"\"\n",
    "        all_table_names = self.get_table_names()\n",
    "        if table_names is not None:\n",
    "            missing_tables = set(table_names).difference(all_table_names)\n",
    "            if missing_tables:\n",
    "                raise ValueError(\n",
    "                    f\"table_names {missing_tables} not found in database\")\n",
    "            all_table_names = table_names\n",
    "\n",
    "        meta_tables = [\n",
    "            tbl\n",
    "            for tbl in self._metadata.sorted_tables\n",
    "            if tbl.name in set(all_table_names)\n",
    "            and not (self.dialect == \"sqlite\" and tbl.name.startswith(\"sqlite_\"))\n",
    "        ]\n",
    "\n",
    "        tables = []\n",
    "        \"\"\"\n",
    "\n",
    "        for table in meta_tables:\n",
    "            if self._custom_table_info and table.name in self._custom_table_info:\n",
    "                tables.append(self._custom_table_info[table.name])\n",
    "            else:\n",
    "                tables.append(table)\n",
    "        tables_dict = {}\n",
    "        for table in tables:\n",
    "            cols = []\n",
    "            col_details = []\n",
    "            pk = []\n",
    "            fks = []\n",
    "            sample_rows = []\n",
    "            num_rows = 0\n",
    "            if do_sampling:\n",
    "                sample_rows = self.get_tbl_samples_dict(table)\n",
    "                num_rows = self.get_rows_of_a_table(table)\n",
    "            for col in table.columns:\n",
    "                cols.append([col.name, str(col.type).split('.')[-1]])\n",
    "                if col.primary_key:\n",
    "                    pk.append(col.name)\n",
    "                if len(col.foreign_keys) > 0:\n",
    "                    for fk in list(col.foreign_keys):\n",
    "                        fks.append([f'{table.name}.{col.name}', '.'.join(fk.target_fullname.split('.')[-2:])])\n",
    "                        \n",
    "                if with_col_details and num_rows > 0:\n",
    "                    distinct_values = self.count_distinct_values_of_a_col(table, col)\n",
    "                    cardinality = len(distinct_values) / num_rows\n",
    "                    # here we use 3 simple conditions to filterout the categorical values:\n",
    "                    # 1. cardinality < 0.3\n",
    "                    # 2. total len(distinct_values) < 20\n",
    "                    # 3. '_id' not in name or name is not equal to 'id'\n",
    "                    if cardinality < 0.5 and len(distinct_values) < 20 and ('_id' not in col.name.lower() or col.name.lower() == 'id'): # maybe a categorical value\n",
    "                        col_details.append({'is_categorical': True, 'cardinality': cardinality, 'distinct_values': distinct_values})\n",
    "                    else:\n",
    "                        col_details.append({'is_categorical': False, 'cardinality': cardinality, 'distinct_values': distinct_values[:20]})\n",
    "            \n",
    "            tables_dict[table.name] = {\n",
    "                'COL': cols,\n",
    "                'PK': pk,\n",
    "                'FK': fks\n",
    "            }\n",
    "            if do_sampling:\n",
    "                tables_dict[table.name]['sample_rows'] = sample_rows\n",
    "            if with_col_details:\n",
    "                tables_dict[table.name]['COL_DETAILS'] = col_details \n",
    "        return tables_dict\n",
    "    \n",
    "    def get_tbl_samples_dict(self, table):\n",
    "        sample_rows_dict = {}\n",
    "        if self._sample_rows_in_table_info:\n",
    "            # build the select command\n",
    "            command = select(table).limit(self._sample_rows_in_table_info)\n",
    "\n",
    "            # save the command in string format\n",
    "            select_star = (\n",
    "                f\"SELECT * FROM '{table.name}' LIMIT \"\n",
    "                f\"{self._sample_rows_in_table_info}\"\n",
    "            )\n",
    "\n",
    "            # save the columns\n",
    "            columns = [col.name for col in table.columns]\n",
    "\n",
    "            # get the sample rows\n",
    "            try:\n",
    "                with self._engine.connect() as connection:\n",
    "                    try:\n",
    "                        sample_rows = connection.execute(command)\n",
    "                        # shorten values in the sample rows\n",
    "                    \n",
    "                        sample_rows = list(\n",
    "                            map(lambda ls: [str(i)[:100]\n",
    "                                for i in ls], sample_rows)\n",
    "                        )\n",
    "                    except TypeError as e:\n",
    "                        # print(\"***Back to literal querying...***\")\n",
    "                        sample_rows = connection.exec_driver_sql(select_star)\n",
    "                        # shorten values in the sample rows\n",
    "                        sample_rows = list(\n",
    "                            map(lambda ls: [str(i)[:100]\n",
    "                                for i in ls], sample_rows)\n",
    "                        )\n",
    "                sample_rows_T = list(map(list, zip(*sample_rows)))\n",
    "                for col, rows in zip(columns, sample_rows_T):\n",
    "                    sample_rows_dict[col] = rows\n",
    "            except ProgrammingError:\n",
    "                print('Warning: sampling error')\n",
    "                sample_rows_dict = {}\n",
    "        return sample_rows_dict\n",
    "\n",
    "    def get_rows_of_a_table(self, table):\n",
    "        command = select(func.count()).select_from(table)\n",
    "        try:\n",
    "            with self._engine.connect() as connection:\n",
    "                num_rows = connection.execute(command)\n",
    "                # print(table.name)\n",
    "                return num_rows.scalar()\n",
    "        except ProgrammingError:\n",
    "                warnings.warn('Fetching categorical values error')\n",
    "                return None\n",
    "    \n",
    "    def count_distinct_values_of_a_col(self, table, column, num_limit=100):\n",
    "        command = select(func.count(column), column).group_by(column).order_by(func.count(column).desc()).limit(num_limit)\n",
    "        try:\n",
    "            with self._engine.connect() as connection:\n",
    "                try:\n",
    "                    sample_rows = connection.execute(command).fetchall()\n",
    "                    # print(table.name, column.name)\n",
    "                    return [list(r) for r in sample_rows]\n",
    "                except ValueError as e:\n",
    "                    print(f\"ValueError: {e.__traceback__}\")\n",
    "                    # backdraw to use exec_driver_sql method\n",
    "                    select_str = (\n",
    "                        f\"SELECT COUNT(*) FROM '{table.name}' GROUP BY {column.name} ORDER BY COUNT(*) LIMIT {num_limit}\")\n",
    "                    sample_rows = connection.exec_driver_sql(select_str).fetchall()\n",
    "                    return [list(r) for r in sample_rows]\n",
    "        except ProgrammingError:\n",
    "                print('Warning: categorical error')\n",
    "                return []\n",
    "    \n",
    "    def run(self, command: str, fetch: str = \"all\", fmt = \"str\") -> str:\n",
    "        \"\"\"Execute a SQL command and return a string representing the results.\n",
    "\n",
    "        If the statement returns rows, a string of the results is returned.\n",
    "        If the statement returns no rows, an empty string is returned.\n",
    "        \"\"\"\n",
    "        with self._engine.begin() as connection:\n",
    "            if self._schema is not None:\n",
    "                if self.dialect == \"snowflake\":\n",
    "                    connection.exec_driver_sql(\n",
    "                        f\"ALTER SESSION SET search_path='{self._schema}'\"\n",
    "                    )\n",
    "                elif self.dialect == \"bigquery\":\n",
    "                    connection.exec_driver_sql(f\"SET @@dataset_id='{self._schema}'\")\n",
    "                else:\n",
    "                    connection.exec_driver_sql(f\"SET search_path TO {self._schema}\")\n",
    "            cursor = connection.execute(text(command))\n",
    "            if cursor.returns_rows:\n",
    "                if fetch == \"all\":\n",
    "                    result = cursor.fetchall()\n",
    "                elif fetch == \"one\":\n",
    "                    result = cursor.fetchone()[0]\n",
    "                else:\n",
    "                    raise ValueError(\n",
    "                        \"Fetch parameter must be either 'one' or 'all'\")\n",
    "                if fmt == \"str\":\n",
    "                    return str(result)\n",
    "                elif fmt == \"list\":\n",
    "                    return list(result)\n",
    "        return \"\"\n",
    "\n",
    "    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:\n",
    "        \"\"\"Get information about specified tables.\n",
    "\n",
    "        Follows best practices as specified in: Rajkumar et al, 2022\n",
    "        (https://arxiv.org/abs/2204.00498)\n",
    "\n",
    "        If `sample_rows_in_table_info`, the specified number of sample rows will be\n",
    "        appended to each table description. This can increase performance as\n",
    "        demonstrated in the paper.\n",
    "        \"\"\"\n",
    "        try:\n",
    "            return self.get_table_info(table_names)\n",
    "        except ValueError as e:\n",
    "            \"\"\"Format the error message\"\"\"\n",
    "            return f\"Error: {e}\"\n",
    "\n",
    "    def run_no_throw(self, command: str, fetch: str = \"all\") -> str:\n",
    "        \"\"\"Execute a SQL command and return a string representing the results.\n",
    "\n",
    "        If the statement returns rows, a string of the results is returned.\n",
    "        If the statement returns no rows, an empty string is returned.\n",
    "\n",
    "        If the statement throws an error, the error message is returned.\n",
    "        \"\"\"\n",
    "        try:\n",
    "            return self.run(command, fetch)\n",
    "        except SQLAlchemyError as e:\n",
    "            \"\"\"Format the error message\"\"\"\n",
    "            return f\"Error: {e}\"\n",
    "        \n",
    "    def dict2str(self, d):\n",
    "        text = []\n",
    "        for t,v in d.items():\n",
    "            _tbl = f'{t}:'\n",
    "            cols = []\n",
    "            pks = ['PK:']\n",
    "            fks = ['FK:']\n",
    "            for col in v['COL']:\n",
    "                cols.append(f'{col[0]}:{self.aliastype(col[1])}')\n",
    "            for pk in v['PK']:\n",
    "                pks.append(pk)\n",
    "            for fk in v['FK']:\n",
    "                fks.append('='.join(list(fk)))\n",
    "            \n",
    "            tbl = '\\n'.join([_tbl, ', '.join(cols), ' '.join(pks), ' '.join(fks)])\n",
    "            text.append(tbl)\n",
    "        return '\\n'.join(text)\n",
    "        \n",
    "    def aliastype(self, t):\n",
    "        _t = t[:3].lower()\n",
    "        if _t in ['int', 'tin', 'sma', 'med', 'big','uns', 'rea', 'dou', 'num', 'dec', 'tim']:\n",
    "            res = 'N' # numerical value\n",
    "        elif _t in ['tex', 'var', 'cha', 'nch', 'nat', 'nva', 'clo']:\n",
    "            res = 'T' # text value\n",
    "        elif _t in ['boo']:\n",
    "            res = 'B'\n",
    "        elif _t in ['dat']:\n",
    "            res = 'D'\n",
    "        else:\n",
    "            raise ValueError('Unsupported data type')\n",
    "        return res\n",
    "    \n",
    "    @lru_cache(maxsize=1000)\n",
    "    def cached_query_results(self, query, limit_num=100):\n",
    "        session = self.Session()\n",
    "        try:\n",
    "            result = session.execute(query).fetchmany(limit_num)\n",
    "            return result\n",
    "        except Exception as e:\n",
    "            print(f\"Error executing query: {query}. Error: {str(e)}\")\n",
    "            raise\n",
    "        finally:\n",
    "            session.close()\n",
    "\n"
   ]
  },
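  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instantiate the wrapper so that `db` is defined before the final cell queries it. The SQLite path below is a placeholder; point it at any database under the Spider `database/` directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# placeholder path: swap in an actual Spider SQLite file on disk\n",
    "db_path = './seq2seq/datasets/spider/database/concert_singer/concert_singer.sqlite'\n",
    "db = SQLDatabase.from_uri(f'sqlite:///{db_path}', sample_rows_in_table_info=3)\n",
    "print(db.dialect, db.get_usable_table_names())\n",
    "# compact schema string built from get_table_info_dict\n",
    "print(db.dict2str(db.get_table_info_dict(do_sampling=False, with_col_details=False)))\n",
    "print(db.run('SELECT 1', fetch='one'))"
   ]
  },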
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "db.get_table_info_dict(do_sampling=False, with_col_details=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}