diff --git a/GANBLR/GANBLR_Extend.ipynb b/GANBLR/GANBLR_Extend.ipynb new file mode 100644 index 0000000..6bd6837 --- /dev/null +++ b/GANBLR/GANBLR_Extend.ipynb @@ -0,0 +1,2269 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 66, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "adkv0_8qNCbK", + "outputId": "3609e0f9-8d59-40b0-c1c0-87f36d81208f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n", + "Requirement already satisfied: pgmpy in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (0.1.13)\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: You are using pip version 20.2.3; however, version 22.0.4 is available.\n", + "You should consider upgrading via the 'C:\\Users\\rukam\\AppData\\Local\\Programs\\Python\\Python38\\python.exe -m pip install --upgrade pip' command.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: torch in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from pgmpy) (1.7.1)\n", + "Requirement already satisfied: pandas in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from pgmpy) (1.2.1)\n", + "Requirement already satisfied: statsmodels in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from pgmpy) (0.12.2)\n", + "Requirement already satisfied: networkx in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from pgmpy) (2.5)\n", + "Requirement already satisfied: tqdm in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from pgmpy) (4.56.2)\n", + "Requirement already satisfied: numpy in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from pgmpy) (1.18.5)\n", + "Requirement already satisfied: scipy in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from pgmpy) (1.6.0)\n", + "Requirement already satisfied: scikit-learn in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from pgmpy) (0.24.1)\n", + "Requirement already satisfied: pyparsing in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from pgmpy) (2.4.7)\n", + "Requirement already satisfied: joblib in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from pgmpy) (1.0.0)\n", + "Requirement already satisfied: typing-extensions in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from torch->pgmpy) (3.7.4.3)\n", + "Requirement already satisfied: pytz>=2017.3 in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from pandas->pgmpy) (2020.5)\n", + "Requirement already satisfied: python-dateutil>=2.7.3 in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from pandas->pgmpy) (2.8.1)\n", + "Requirement already satisfied: patsy>=0.5 in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from statsmodels->pgmpy) (0.5.1)\n", + "Requirement already satisfied: decorator>=4.3.0 in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from networkx->pgmpy) (4.4.2)\n", + "Requirement already satisfied: 
threadpoolctl>=2.0.0 in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from scikit-learn->pgmpy) (2.1.0)\n", + "Requirement already satisfied: six>=1.5 in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from python-dateutil>=2.7.3->pandas->pgmpy) (1.15.0)\n", + "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simpleNote: you may need to restart the kernel to use updated packages.\n", + "Requirement already satisfied: pyitlib in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (0.2.2)\n", + "Requirement already satisfied: scikit-learn>=0.16.0 in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from pyitlib) (0.24.1)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: You are using pip version 20.2.3; however, version 22.0.4 is available.\n", + "You should consider upgrading via the 'C:\\Users\\rukam\\AppData\\Local\\Programs\\Python\\Python38\\python.exe -m pip install --upgrade pip' command.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Requirement already satisfied: pandas>=0.20.2numpy>=1.9.2 in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from pyitlib) (1.2.1)\n", + "Requirement already satisfied: future>=0.16.0 in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from pyitlib) (0.18.2)\n", + "Requirement already satisfied: scipy>=1.0.1 in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from pyitlib) (1.6.0)\n", + "Requirement already satisfied: numpy>=1.13.3 in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from scikit-learn>=0.16.0->pyitlib) (1.18.5)\n", + "Requirement already satisfied: joblib>=0.11 in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from scikit-learn>=0.16.0->pyitlib) (1.0.0)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from scikit-learn>=0.16.0->pyitlib) (2.1.0)\n", + "Requirement already satisfied: pytz>=2017.3 in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from pandas>=0.20.2numpy>=1.9.2->pyitlib) (2020.5)\n", + "Requirement already satisfied: python-dateutil>=2.7.3 in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from pandas>=0.20.2numpy>=1.9.2->pyitlib) (2.8.1)\n", + "Requirement already satisfied: six>=1.5 in c:\\users\\rukam\\appdata\\local\\programs\\python\\python38\\lib\\site-packages (from python-dateutil>=2.7.3->pandas>=0.20.2numpy>=1.9.2->pyitlib) (1.15.0)\n" + ] + } + ], + "source": [ + "%pip install pgmpy\n", + "%pip install pyitlib" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import tensorflow as tf\n", + "from tensorflow.python.ops import math_ops, array_ops\n", + "from tensorflow.keras.constraints import Constraint\n", + "from tensorflow.keras.activations import softmax\n", + "class softmax_weight(Constraint):\n", + " \"\"\"Constrains weight tensors to be under softmax `.\"\"\"\n", + " \n", + " def __init__(self,feature_uniques):\n", + " idxs = math_ops.cumsum([0] + feature_uniques)\n", + " idxs = [i.numpy() for i in idxs]\n", + " self.feature_idxs = [\n", + " (idxs[i],idxs[i+1]) for i in 
range(len(idxs)-1)\n", +    "        ]\n", +    "        \n", +    "    def __call__(self, w): \n", +    "        w_new = [\n", +    "            math_ops.log(softmax(w[i:j,:], axis=0))\n", +    "            for i,j in self.feature_idxs\n", +    "        ]\n", +    "        return tf.concat(w_new, 0)\n", +    "    \n", +    "    def get_config(self):\n", +    "        return {'feature_idxs': self.feature_idxs}" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 68, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "def sample(*arrays, size=None, frac=None):\n", +    "    '''\n", +    "    Randomly sample from the given arrays.\n", +    "    \n", +    "    Note: arrays must be equal-length\n", +    "    \n", +    "    size = None (default) indicates that a permutation of the given arrays is returned.\n", +    "    '''\n", +    "    if len(arrays) < 1:\n", +    "        return None\n", +    "    if frac is not None and frac <= 1 and frac > 0:\n", +    "        size = int(len(arrays[0]) * frac)\n", +    "    if size is None:\n", +    "        size = len(arrays[0])\n", +    "    \n", +    "    random_idxs = np.random.permutation(len(arrays[0]))[:size]\n", +    "    results = []\n", +    "    for arr in arrays:\n", +    "        results.append(arr[random_idxs])\n", +    "    return results" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": {}, +   "source": [ +    "# GANBLR" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 75, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "def elr_loss(KL_LOSS):\n", +    "    def loss(y_true, y_pred):\n", +    "        return tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)+ KL_LOSS\n", +    "    return loss" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 76, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "#from tensorflow.keras.layers import Concatenate\n", +    "from tensorflow.keras.backend import clear_session\n", +    "from tensorflow.keras.layers import Dense\n", +    "from sklearn.preprocessing import OneHotEncoder\n", +    "class GANBLR:\n", +    "    def __init__(self, log=True):\n", +    "        self.g = None\n", +    "        self.d = None\n", +    "        self.log = log\n", +    "        self.g_history = []\n", +    "        self.d_history = []\n", +    "        \n", +    "        self.batch_size = 32 # default value\n", +    "        self.feature_uniques = []\n", +    "        self.class_unique = 0\n", +    "        self.g_input_dim = 0\n", +    "        self.d_input_dim = 0\n", +    "        self.y_counts = []\n", +    "        \n", +    "    def generate(self, size=None, ohe=False):\n", +    "        from pgmpy.models import NaiveBayes\n", +    "        from pgmpy.sampling import BayesianModelSampling\n", +    "        from pgmpy.factors.discrete import TabularCPD\n", +    "        #basic variables\n", +    "        weights = self.g.get_weights()[0]\n", +    "        n_features = len(self.feature_uniques)\n", +    "        n_classes = weights.shape[1]\n", +    "        n_samples = np.sum(self.class_counts)\n", +    "        #cut weights by feature uniques\n", +    "        idxs = np.cumsum([0] + self.feature_uniques)\n", +    "        feature_idxs = [(idxs[i],idxs[i+1]) for i in range(len(idxs)-1)]\n", +    "        feature_names = [str(i) for i in range(n_features)]\n", +    "        #get cpd of features\n", +    "        feature_probs = np.exp(weights)\n", +    "        feature_cpd_probs = [feature_probs[start:end,:] for start, end in feature_idxs]\n", +    "        feature_cpd_probs = [p/p.sum(axis=0,keepdims=1) for p in feature_cpd_probs]\n", +    "        feature_cpds = [\n", +    "            TabularCPD(name, n_unique, table, evidence=['y'], evidence_card=[n_classes])\n", +    "            for name, n_unique, table in zip(feature_names, self.feature_uniques, feature_cpd_probs)\n", +    "        ]\n", +    "        #get cpd of label\n", +    "        y_probs = (self.class_counts/n_samples).reshape(-1,1)\n", +    "        y_cpd = TabularCPD('y', n_classes, y_probs)\n", +    "        \n", +    "        #define the model\n", +    "        elr = NaiveBayes(feature_names, 'y')\n", +    "        elr.add_cpds(y_cpd, *feature_cpds)\n", +    "        #sampling\n", +    "        sample_size = n_samples if size is None 
else size\n", + " result = BayesianModelSampling(elr).forward_sample(size=sample_size)\n", + " sorted_result = result[feature_names + ['y']].values\n", + " #return\n", + " syn_X, syn_y = sorted_result[:,:-1], sorted_result[:,-1]\n", + " \n", + " if ohe:\n", + " ohe_syn_X = [np.eye(b)[syn_X[:,i]] for i, b in enumerate(self.feature_uniques)]\n", + " ohe_syn_X = np.hstack(ohe_syn_X)\n", + " return ohe_syn_X, syn_y\n", + " else:\n", + " return sorted_result\n", + " \n", + " def fit(self, X, y, epochs\n", + " , batch_size=32, warm_up_epochs=10, categories_=None):\n", + " \n", + " if categories_ is None:\n", + " categories_ = 'auto'\n", + " ohe = OneHotEncoder(categories=categories_).fit(X)\n", + " ohe_X = ohe.transform(X).toarray()\n", + " #feature_uniques = [len(np.unique(X[:,i])) for i in range(X.shape[1])]\n", + " self.feature_uniques = [len(c) for c in ohe.categories_]\n", + " y_unique, y_counts = np.unique(y, return_counts=True)\n", + " self.class_unique = len(y_unique)\n", + " self.class_counts = y_counts\n", + " self.g_input_dim = np.sum(self.feature_uniques)\n", + " self.d_input_dim = X.shape[1]\n", + " self.batch_size = batch_size\n", + " self._build_g()\n", + " self._build_d()\n", + " \n", + " #warm up\n", + " self.g.fit(ohe_X, y, epochs=warm_up_epochs, batch_size=batch_size)\n", + " syn_data = self.generate(size=len(X))\n", + " #real_data = np.concatenate([X, y.reshape(-1,1)], axis=-1)\n", + " for i in range(epochs):\n", + " #prepare data\n", + " real_label = np.ones(len(X))\n", + " syn_label = np.zeros(len(X))\n", + " disc_label = np.concatenate([real_label, syn_label]) \n", + " disc_X = np.vstack([X, syn_data[:,:-1]]) \n", + " disc_X, disc_label = sample(disc_X, disc_label, frac=0.8)\n", + " #train d\n", + " self._train_d(disc_X, disc_label)\n", + " prob_fake = self.d.predict(X)\n", + " ls = np.mean(-np.log(np.subtract(1,prob_fake)))\n", + " #train g\n", + " self._train_g(ohe_X, y, loss=ls)\n", + " syn_data = self.generate(size=len(X))\n", + "\n", + " def _train_g(self, X, y, epochs=1, loss=None):\n", + " if loss is not None:\n", + " clear_session()\n", + " self._build_g(weights=self.g.get_weights(), loss=loss)\n", + " self._build_d(weights=self.d.get_weights())\n", + " \n", + " history = self.g.fit(X, y, epochs=epochs, batch_size=self.batch_size)\n", + " if self.log:\n", + " self.g_history.append(history.history)\n", + " \n", + " def _train_d(self, X, y, epochs=1):\n", + " history = self.d.fit(X, y, batch_size=self.batch_size, epochs=epochs) \n", + " if self.log:\n", + " self.d_history.append(history.history) \n", + " \n", + " def _build_g(self, weights=None, loss=None):\n", + " if loss is None:\n", + " loss = elr_loss(0)\n", + " else:\n", + " loss = elr_loss(loss)\n", + " constraint = softmax_weight(self.feature_uniques)\n", + " g = tf.keras.Sequential() \n", + " g.add(Dense(self.class_unique, input_dim=self.g_input_dim, activation='softmax',kernel_constraint=constraint))\n", + " g.compile(loss=loss, optimizer='adam', metrics=['accuracy'])\n", + " self.g = g\n", + " \n", + " if weights is not None:\n", + " g.set_weights(weights)\n", + " return g\n", + " \n", + " def _build_d(self, weights=None):\n", + " d = tf.keras.Sequential()\n", + " d.add(Dense(1, input_dim=self.d_input_dim, activation='sigmoid'))\n", + " d.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])\n", + " \n", + " self.d = d\n", + " \n", + " if weights is not None:\n", + " d.set_weights(weights)\n", + " return d" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 
GANBLR Extend" +   ] +  }, +  { +   "cell_type": "code", +   "execution_count": 77, +   "metadata": {}, +   "outputs": [], +   "source": [ +    "#from tensorflow.keras.layers import Concatenate\n", +    "from tensorflow.keras.backend import clear_session\n", +    "from tensorflow.keras.layers import Dense\n", +    "from sklearn.preprocessing import OneHotEncoder\n", +    "class GANBLR_Ext:\n", +    "    def __init__(self, log=True):\n", +    "        self.g = None\n", +    "        self.d = None\n", +    "        self.log = log\n", +    "        self.g_history = []\n", +    "        self.d_history = []\n", +    "        \n", +    "        self.batch_size = 32 # default value\n", +    "        self.feature_uniques = []\n", +    "        self.class_unique = 0\n", +    "        self.g_input_dim = 0\n", +    "        self.d_input_dim = 0\n", +    "        self.y_counts = []\n", +    "        \n", +    "    def generate(self, size=None, ohe=False):\n", +    "        from pgmpy.models import NaiveBayes\n", +    "        from pgmpy.sampling import BayesianModelSampling\n", +    "        from pgmpy.factors.discrete import TabularCPD\n", +    "        #basic variables\n", +    "        weights = self.g.get_weights()[0]\n", +    "        n_features = len(self.feature_uniques)\n", +    "        n_classes = weights.shape[1]\n", +    "        n_samples = np.sum(self.class_counts)\n", +    "        #cut weights by feature uniques\n", +    "        idxs = np.cumsum([0] + self.feature_uniques)\n", +    "        feature_idxs = [(idxs[i],idxs[i+1]) for i in range(len(idxs)-1)]\n", +    "        feature_names = [str(i) for i in range(n_features)]\n", +    "        #get cpd of features\n", +    "        feature_probs = np.exp(weights)\n", +    "        feature_cpd_probs = [feature_probs[start:end,:] for start, end in feature_idxs]\n", +    "        feature_cpd_probs = [p/p.sum(axis=0,keepdims=1) for p in feature_cpd_probs]\n", +    "        feature_cpds = [\n", +    "            TabularCPD(name, n_unique, table, evidence=['y'], evidence_card=[n_classes])\n", +    "            for name, n_unique, table in zip(feature_names, self.feature_uniques, feature_cpd_probs)\n", +    "        ]\n", +    "        #get cpd of label\n", +    "        y_probs = (self.class_counts/n_samples).reshape(-1,1)\n", +    "        y_cpd = TabularCPD('y', n_classes, y_probs)\n", +    "        \n", +    "        #define the model\n", +    "        elr = NaiveBayes(feature_names, 'y')\n", +    "        elr.add_cpds(y_cpd, *feature_cpds)\n", +    "        #sampling\n", +    "        sample_size = n_samples if size is None else size\n", +    "        result = BayesianModelSampling(elr).forward_sample(size=sample_size)\n", +    "        sorted_result = result[feature_names + ['y']].values\n", +    "        #return\n", +    "        syn_X, syn_y = sorted_result[:,:-1], sorted_result[:,-1]\n", +    "        \n", +    "        if ohe:\n", +    "            ohe_syn_X = [np.eye(b)[syn_X[:,i]] for i, b in enumerate(self.feature_uniques)]\n", +    "            ohe_syn_X = np.hstack(ohe_syn_X)\n", +    "            return ohe_syn_X, syn_y\n", +    "        else:\n", +    "            return sorted_result\n", +    "    \n", +    "    def fit(self, X, y, epochs\n", +    "            , batch_size=32, warm_up_epochs=10, categories_=None):\n", +    "        \n", +    "        if categories_ is None:\n", +    "            categories_ = 'auto'\n", +    "        ohe = OneHotEncoder(categories=categories_).fit(X)\n", +    "        ohe_X = ohe.transform(X).toarray()\n", +    "        #feature_uniques = [len(np.unique(X[:,i])) for i in range(X.shape[1])]\n", +    "        self.feature_uniques = [len(c) for c in ohe.categories_]\n", +    "        y_unique, y_counts = np.unique(y, return_counts=True)\n", +    "        self.class_unique = len(y_unique)\n", +    "        self.class_counts = y_counts\n", +    "        self.g_input_dim = np.sum(self.feature_uniques)\n", +    "        self.d_input_dim = X.shape[1] + self.class_unique\n", +    "        self.batch_size = batch_size\n", +    "        self._build_g()\n", +    "        self._build_d()\n", +    "        \n", +    "        real_label = np.ones(len(X))\n", +    "        syn_label = np.zeros(len(X))\n", +    "        #warm up\n", +    "        self.g.fit(ohe_X, y, epochs=warm_up_epochs, batch_size=batch_size)\n", +    "        
syn_data = self.generate(size=len(X))\n", + " #real_data = np.concatenate([X, y.reshape(-1,1)], axis=-1)\n", + " for i in range(epochs):\n", + " #prepare data \n", + " disc_label = np.concatenate([real_label, syn_label])\n", + " disc_X = np.vstack([X, syn_data[:,:-1]])\n", + " g_logits = self.g(tf.convert_to_tensor(ohe.transform(disc_X).toarray())).numpy()\n", + " disc_X = np.hstack([g_logits, disc_X])\n", + " disc_X, disc_label = sample(disc_X, disc_label, frac=0.8)\n", + " #train d\n", + " self._train_d(disc_X, disc_label)\n", + " prob_fake = self.d.predict(np.hstack([g_logits[:len(X)], X]))\n", + " ls = np.mean(-np.log(np.subtract(1,prob_fake)))\n", + " #train g\n", + " self._train_g(ohe_X, y, loss=ls)\n", + " syn_data = self.generate(size=len(X))\n", + "\n", + " def _train_g(self, X, y, epochs=1, loss=None):\n", + " if loss is not None:\n", + " g_weights = self.g.get_weights()\n", + " d_weights = self.d.get_weights()\n", + " clear_session()\n", + " self._build_g(weights=g_weights, loss=loss)\n", + " self._build_d(weights=d_weights)\n", + " \n", + " history = self.g.fit(X, y, epochs=epochs, batch_size=self.batch_size)\n", + " if self.log:\n", + " self.g_history.append(history.history)\n", + " \n", + " def _train_d(self, X, y, epochs=1):\n", + " history = self.d.fit(X, y, batch_size=self.batch_size, epochs=epochs) \n", + " if self.log:\n", + " self.d_history.append(history.history) \n", + " \n", + " def _build_g(self, weights=None, loss=None):\n", + " if loss is None:\n", + " loss = elr_loss(0)\n", + " else:\n", + " loss = elr_loss(loss)\n", + " constraint = softmax_weight(self.feature_uniques)\n", + " g = tf.keras.Sequential() \n", + " g.add(Dense(self.class_unique, input_dim=self.g_input_dim, activation='softmax',kernel_constraint=constraint))\n", + " g.compile(loss=loss, optimizer='adam', metrics=['accuracy'])\n", + " self.g = g\n", + " \n", + " if weights is not None:\n", + " g.set_weights(weights)\n", + " return g\n", + " \n", + " def _build_d(self, weights=None):\n", + " d = tf.keras.Sequential()\n", + " d.add(Dense(1, input_dim=self.d_input_dim, activation='sigmoid'))\n", + " d.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])\n", + " \n", + " self.d = d\n", + " \n", + " if weights is not None:\n", + " d.set_weights(weights)\n", + " return d" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Test" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(12960, 9)\n", + "5\n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "import matplotlib.pyplot as plt\n", + "from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "df = pd.read_csv('car_value.csv',dtype='category')\n", + "df1 = OrdinalEncoder().fit_transform(df).astype('int')\n", + "print(df.shape)\n", + "X = df1[:,0:-1]\n", + "y = df1[:,-1] \n", + "print(len(np.unique(y)))" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "metadata": {}, + "outputs": [], + "source": [ + "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)\n", + "categories = OneHotEncoder().fit(X).categories_" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/5\n", + "203/203 [==============================] - 0s 
1ms/step - loss: 1.4299 - accuracy: 0.3100\n", + "Epoch 2/5\n", + "203/203 [==============================] - 0s 1ms/step - loss: 1.0296 - accuracy: 0.6759\n", + "Epoch 3/5\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.8309 - accuracy: 0.8576\n", + "Epoch 4/5\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.7016 - accuracy: 0.8901\n", + "Epoch 5/5\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.6102 - accuracy: 0.9039\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 34.48it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.7771 - accuracy: 0.4924\n", + "203/203 [==============================] - 0s 1ms/step - loss: 1.2579 - accuracy: 0.9079\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 35.86it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.7313 - accuracy: 0.4951\n", + "203/203 [==============================] - 0s 2ms/step - loss: 1.1829 - accuracy: 0.9113\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.86it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.7041 - accuracy: 0.5121\n", + "203/203 [==============================] - 0s 2ms/step - loss: 1.1240 - accuracy: 0.9131\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.23it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6967 - accuracy: 0.5179\n", + "203/203 [==============================] - 0s 1ms/step - loss: 1.0586 - accuracy: 0.9128\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.88it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6943 - accuracy: 0.5245\n", + "203/203 [==============================] - 0s 1ms/step - loss: 1.0139 - accuracy: 0.9145\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 35.86it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6912 - accuracy: 0.5403\n", + "203/203 [==============================] - 0s 1ms/step - loss: 1.0286 - accuracy: 0.9133\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.81it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6906 - accuracy: 0.5398\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9920 - accuracy: 0.9144\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 35.85it/s]\n" + ] + }, + { + "name": 
"stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6892 - accuracy: 0.5478\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9704 - accuracy: 0.9157\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 35.71it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6870 - accuracy: 0.5537\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9713 - accuracy: 0.9127\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.66it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6861 - accuracy: 0.5578\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9666 - accuracy: 0.9136\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 38.29it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6862 - accuracy: 0.5560\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9445 - accuracy: 0.9139\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 38.44it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6848 - accuracy: 0.5574\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9512 - accuracy: 0.9148\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 38.46it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6839 - accuracy: 0.5639\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9416 - accuracy: 0.9165\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.04it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6814 - accuracy: 0.5673\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9426 - accuracy: 0.9171\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.50it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6823 - accuracy: 0.5659\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9191 - accuracy: 0.9174\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 39.30it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6805 - accuracy: 0.5731\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9257 - accuracy: 0.9184\n" + ] + }, + { + "name": 
"stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 35.64it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6811 - accuracy: 0.5719\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9610 - accuracy: 0.9193\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.97it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6768 - accuracy: 0.5859\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9500 - accuracy: 0.9196\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 39.65it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6747 - accuracy: 0.5859\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9261 - accuracy: 0.9207\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.66it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6772 - accuracy: 0.5854\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9510 - accuracy: 0.9215\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 34.75it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6743 - accuracy: 0.5859\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9426 - accuracy: 0.9218\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.49it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6749 - accuracy: 0.5890\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9724 - accuracy: 0.9230\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.46it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6692 - accuracy: 0.6020\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9804 - accuracy: 0.9239\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.29it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6719 - accuracy: 0.5933\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9629 - accuracy: 0.9242\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.50it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 
0.6679 - accuracy: 0.6005\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9274 - accuracy: 0.9247\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.66it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - ETA: 0s - loss: 0.6685 - accuracy: 0.60 - 0s 1ms/step - loss: 0.6682 - accuracy: 0.6029\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9976 - accuracy: 0.9269\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 39.65it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6672 - accuracy: 0.6023\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9732 - accuracy: 0.9269\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.88it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6661 - accuracy: 0.6019\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9752 - accuracy: 0.9258\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 35.62it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6667 - accuracy: 0.6061\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9386 - accuracy: 0.9264\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.34it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6666 - accuracy: 0.6019\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9627 - accuracy: 0.9267\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 35.86it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6682 - accuracy: 0.6010\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9339 - accuracy: 0.9278\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.19it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6664 - accuracy: 0.6043\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9462 - accuracy: 0.9267\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 35.57it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6630 - accuracy: 0.6055\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9854 - accuracy: 0.9270\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 
100%|██████████| 9/9 [00:00<00:00, 35.16it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6651 - accuracy: 0.6102\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9378 - accuracy: 0.9272\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.35it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6648 - accuracy: 0.6072\n", + "203/203 [==============================] - 0s 2ms/step - loss: 0.9387 - accuracy: 0.9276\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 35.02it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6640 - accuracy: 0.6080\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9752 - accuracy: 0.9273\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 38.13it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6609 - accuracy: 0.6096\n", + "203/203 [==============================] - 0s 2ms/step - loss: 0.9613 - accuracy: 0.9269\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.89it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6630 - accuracy: 0.6109\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9358 - accuracy: 0.9267\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 38.79it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6622 - accuracy: 0.6125\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9638 - accuracy: 0.9272\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 39.65it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6599 - accuracy: 0.6061\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9538 - accuracy: 0.9265\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.63it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6588 - accuracy: 0.6127\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9678 - accuracy: 0.9269\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 35.86it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6631 - accuracy: 0.6054\n", + "203/203 [==============================] - 0s 
1ms/step - loss: 0.9916 - accuracy: 0.9267\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 31.36it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6605 - accuracy: 0.6136\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9814 - accuracy: 0.9267\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 39.30it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6613 - accuracy: 0.6071\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9246 - accuracy: 0.9272\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 38.78it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6595 - accuracy: 0.6136: 0s - loss: 0.6579 - accuracy: 0.\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9796 - accuracy: 0.9265\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.50it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6602 - accuracy: 0.6071\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9723 - accuracy: 0.9261\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 33.20it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 2ms/step - loss: 0.6611 - accuracy: 0.6101\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9841 - accuracy: 0.9262\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.29it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6594 - accuracy: 0.6098\n", + "203/203 [==============================] - 0s 1ms/step - loss: 1.0471 - accuracy: 0.9265\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 39.47it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6614 - accuracy: 0.6090\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9977 - accuracy: 0.9267\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.59it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6613 - accuracy: 0.6069\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9651 - accuracy: 0.9262\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 38.59it/s]\n" + ] + }, + { + "name": "stdout", + 
"output_type": "stream", + "text": [ + "Epoch 1/5\n", + "203/203 [==============================] - 0s 1ms/step - loss: 1.3941 - accuracy: 0.3898\n", + "Epoch 2/5\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9942 - accuracy: 0.6287\n", + "Epoch 3/5\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.8029 - accuracy: 0.8136\n", + "Epoch 4/5\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.6794 - accuracy: 0.8795\n", + "Epoch 5/5\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.5926 - accuracy: 0.9020\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.03it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.8877 - accuracy: 0.4939\n", + "203/203 [==============================] - 0s 1ms/step - loss: 1.3265 - accuracy: 0.9093\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.41it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.7820 - accuracy: 0.4959\n", + "203/203 [==============================] - 0s 1ms/step - loss: 1.1762 - accuracy: 0.9111\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.44it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.7363 - accuracy: 0.4924\n", + "203/203 [==============================] - 0s 1ms/step - loss: 1.1135 - accuracy: 0.9142\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.88it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.7098 - accuracy: 0.5019\n", + "203/203 [==============================] - 0s 1ms/step - loss: 1.0600 - accuracy: 0.9145\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.44it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6989 - accuracy: 0.5061\n", + "203/203 [==============================] - 0s 1ms/step - loss: 1.0151 - accuracy: 0.9136\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 38.62it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6932 - accuracy: 0.5195\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9743 - accuracy: 0.9139\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 35.02it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6916 - accuracy: 0.5261\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9653 - accuracy: 0.9133\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 34.09it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6902 - accuracy: 0.5325\n", + "203/203 [==============================] - 0s 2ms/step - loss: 0.9431 - accuracy: 0.9137\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.14it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6893 - accuracy: 0.5393\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9411 - accuracy: 0.9147\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.54it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6891 - accuracy: 0.5427\n", + "203/203 [==============================] - 0s 2ms/step - loss: 0.9281 - accuracy: 0.9127\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 25.14it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6883 - accuracy: 0.5441\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9313 - accuracy: 0.9151\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.27it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6878 - accuracy: 0.5496\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9031 - accuracy: 0.9130\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 30.31it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6884 - accuracy: 0.5399\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9682 - accuracy: 0.9151\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.34it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6884 - accuracy: 0.5420\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.8978 - accuracy: 0.9156\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 35.42it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6864 - accuracy: 0.5545\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9149 - accuracy: 0.9181\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.59it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6859 - accuracy: 0.5550\n", + "203/203 
[==============================] - 0s 1ms/step - loss: 0.9384 - accuracy: 0.9170\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 38.88it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6863 - accuracy: 0.5544\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.8889 - accuracy: 0.9185\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.58it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6867 - accuracy: 0.5536\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.8988 - accuracy: 0.9193\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 35.43it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6865 - accuracy: 0.5496\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9020 - accuracy: 0.9193\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 34.88it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6866 - accuracy: 0.5553\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9137 - accuracy: 0.9202\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.44it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6847 - accuracy: 0.5589\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9006 - accuracy: 0.9218: 0s - loss: 0.9001 - accuracy: 0.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 15.37it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6871 - accuracy: 0.5525\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.8774 - accuracy: 0.9233\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.00it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6843 - accuracy: 0.5677\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.8852 - accuracy: 0.9228\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 39.30it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6849 - accuracy: 0.5590\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9181 - accuracy: 0.9259\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.50it/s]\n" 
+ ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6839 - accuracy: 0.5661: 0s - loss: 0.6838 - accuracy: 0.56\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9038 - accuracy: 0.9242\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 38.13it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6838 - accuracy: 0.5642\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9089 - accuracy: 0.9250\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.19it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6839 - accuracy: 0.5652\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.8945 - accuracy: 0.9258\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 39.82it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6840 - accuracy: 0.5627\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9196 - accuracy: 0.9253\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.97it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6839 - accuracy: 0.5657\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.8914 - accuracy: 0.9269\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 39.13it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " 1/324 [..............................] - ETA: 0s - loss: 0.6906 - accuracy: 0.6562WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0000s vs `on_train_batch_end` time: 0.0010s). 
Check your callbacks.\n", + "324/324 [==============================] - 0s 1ms/step - loss: 0.6855 - accuracy: 0.5607\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9037 - accuracy: 0.9267\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.27it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6829 - accuracy: 0.5672\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9273 - accuracy: 0.9267\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.42it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6839 - accuracy: 0.5667\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.8825 - accuracy: 0.9261\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 38.13it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6830 - accuracy: 0.5634\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9030 - accuracy: 0.9255\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.29it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6829 - accuracy: 0.5672\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9127 - accuracy: 0.9259\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 35.99it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6835 - accuracy: 0.5643\n", + "203/203 [==============================] - 0s 2ms/step - loss: 0.9144 - accuracy: 0.9264\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.97it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6840 - accuracy: 0.5688\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9214 - accuracy: 0.9269\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 38.46it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6814 - accuracy: 0.5693\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9317 - accuracy: 0.9273\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.82it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6813 - accuracy: 0.5672\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.8722 - accuracy: 0.9267\n" + ] + }, + { + "name": "stderr", + "output_type": 
"stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.44it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6821 - accuracy: 0.5674\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.8757 - accuracy: 0.9272\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.66it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6812 - accuracy: 0.5705\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.8934 - accuracy: 0.9269\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 38.95it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6799 - accuracy: 0.5734\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.8888 - accuracy: 0.9270\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.74it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6806 - accuracy: 0.5736\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9002 - accuracy: 0.9262\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.14it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6805 - accuracy: 0.5724\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9191 - accuracy: 0.9262\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.04it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6818 - accuracy: 0.5656\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.8944 - accuracy: 0.9267\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.81it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6790 - accuracy: 0.5806\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9259 - accuracy: 0.9272\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 35.16it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6798 - accuracy: 0.5781\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9072 - accuracy: 0.9278\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.14it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6795 - accuracy: 0.5745\n", 
+ "203/203 [==============================] - 0s 1ms/step - loss: 0.8916 - accuracy: 0.9276\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 36.21it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " 1/324 [..............................] - ETA: 0s - loss: 0.6607 - accuracy: 0.6562WARNING:tensorflow:Callbacks method `on_train_batch_begin` is slow compared to the batch time (batch time: 0.0000s vs `on_train_batch_begin` time: 0.0010s). Check your callbacks.\n", + "WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0000s vs `on_train_batch_end` time: 0.0010s). Check your callbacks.\n", + "324/324 [==============================] - 0s 1ms/step - loss: 0.6781 - accuracy: 0.5782\n", + "203/203 [==============================] - 0s 1ms/step - loss: 0.9033 - accuracy: 0.9273\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 37.50it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6796 - accuracy: 0.5812\n", + "203/203 [==============================] - 0s 2ms/step - loss: 0.9343 - accuracy: 0.9278\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 26.87it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "324/324 [==============================] - 0s 1ms/step - loss: 0.6793 - accuracy: 0.5796\n", + "203/203 [==============================] - 0s 2ms/step - loss: 0.9035 - accuracy: 0.9264\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 25.64it/s]\n" + ] + } + ], + "source": [ + "epochs = 50\n", + "batch_size = 32\n", + "\n", + "clear_session()\n", + "ganblr_ext = GANBLR_Ext()\n", + "ganblr_ext.fit(X_train, y_train, epochs=epochs, warm_up_epochs=5,categories_=categories)\n", + "ganblr = GANBLR()\n", + "ganblr.fit(X_train, y_train, epochs=epochs, warm_up_epochs=5,categories_=categories)" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "metadata": {}, + "outputs": [], + "source": [ + "# g_loss_log = [d['loss'] for d in ganblr.g_history]\n", + "# d_loss_log = [d['loss'] for d in ganblr.d_history]\n", + "# \n", + "# plt.plot(g_loss_log, label='generator loss')\n", + "# plt.plot(d_loss_log, label='discriminator loss')\n", + "# plt.legend()" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "metadata": {}, + "outputs": [], + "source": [ + "# g_loss_log = [d['loss'] for d in ganblr_ext.g_history]\n", + "# d_loss_log = [d['loss'] for d in ganblr_ext.d_history]\n", + "# \n", + "# plt.plot(g_loss_log, label='generator loss')\n", + "# plt.plot(d_loss_log, label='discriminator loss')\n", + "# plt.legend()" + ] + }, + { + "cell_type": "code", + "execution_count": 101, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 28.04it/s]\n", + "Generating for node: 6: 100%|██████████| 9/9 [00:00<00:00, 28.21it/s]\n" + ] + } + ], + "source": [ + "ganblr_syn_data = ganblr.generate(len(X_train), ohe=True)\n", + "ganblr_ext_syn_data = ganblr_ext.generate(len(X_train), ohe=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 
102, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.linear_model import LogisticRegression\n", + "from sklearn.metrics import accuracy_score" + ] + }, + { + "cell_type": "code", + "execution_count": 103, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\rukam\\AppData\\Local\\Programs\\Python\\Python38\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:763: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n", + "C:\\Users\\rukam\\AppData\\Local\\Programs\\Python\\Python38\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:763: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n", + "C:\\Users\\rukam\\AppData\\Local\\Programs\\Python\\Python38\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:763: ConvergenceWarning: lbfgs failed to converge (status=1):\n", + "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", + "\n", + "Increase the number of iterations (max_iter) or scale the data as shown in:\n", + " https://scikit-learn.org/stable/modules/preprocessing.html\n", + "Please also refer to the documentation for alternative solver options:\n", + " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", + " n_iter_i = _check_optimize_result(\n" + ] + } + ], + "source": [ + "ohe = OneHotEncoder(categories=categories, sparse=False)\n", + "lr2 = LogisticRegression().fit(*ganblr_ext_syn_data)\n", + "lr3 = LogisticRegression().fit(ohe.fit_transform(X_train), y_train)\n", + "lr1 = LogisticRegression().fit(*ganblr_syn_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 106, + "metadata": {}, + "outputs": [], + "source": [ + "pred1 = lr1.predict(ohe.fit_transform(X_test))\n", + "pred2 = lr2.predict(ohe.fit_transform(X_test))\n", + "pred3 = lr3.predict(ohe.fit_transform(X_test))" + ] + }, + { + "cell_type": "code", + "execution_count": 107, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TSTR on GANBLR:\t 0.8447530864197531\n", + "TSTR on GANBLR Extend:\t 0.8356481481481481\n", + "TRTR:\t 0.9246913580246914\n" + ] + } + ], + "source": [ + "print('TSTR on GANBLR:\\t', accuracy_score(y_test, pred1))\n", + "print('TSTR on GANBLR Extend:\\t', accuracy_score(y_test, pred2))\n", + "print('TRTR:\\t', accuracy_score(y_test, pred3))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### result backup\n", + "\n", + "adult:\n", + "- TSTR on GANBLR:\t 0.7512796363785267\n", + "- TSTR on GANBLR Extend:\t 0.8077883788542648\n", + "- TRTR:\t 0.8730191228860407\n", + "\n", + "car\n", + "\n", + "- TSTR on GANBLR:\t 0.8148148148148148\n", + "- TSTR on GANBLR Extend:\t 
0.8090277777777778\n", + "- TRTR:\t 0.90625\n", + "\n", + "nursery\n", + "\n", + "- TSTR on GANBLR:\t 0.8447530864197531\n", + "- TSTR on GANBLR Extend:\t 0.8356481481481481\n", + "- TRTR:\t 0.9246913580246914" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "colab": { + "name": "GAN-ELR.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}
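
Note on the commented-out plotting cells (execution counts 99 and 100): they assume `matplotlib.pyplot` is already imported as `plt` and that the fitted models expose `g_history` / `d_history` lists of per-epoch metric dicts. If re-enabled, a self-contained version might look like the sketch below; the attribute names are taken from the commented code itself and are otherwise unverified assumptions.

    import matplotlib.pyplot as plt

    def plot_gan_losses(model, title):
        """Plot per-epoch generator and discriminator losses recorded during fit()."""
        g_loss = [d['loss'] for d in model.g_history]  # generator loss per epoch
        d_loss = [d['loss'] for d in model.d_history]  # discriminator loss per epoch
        plt.plot(g_loss, label='generator loss')
        plt.plot(d_loss, label='discriminator loss')
        plt.title(title)
        plt.xlabel('epoch')
        plt.ylabel('loss')
        plt.legend()
        plt.show()

    # e.g. plot_gan_losses(ganblr, 'GANBLR'); plot_gan_losses(ganblr_ext, 'GANBLR Extend')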
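
Note on the evaluation cells (execution counts 101-107) and the "result backup" figures: they follow the TSTR/TRTR protocol, training a LogisticRegression either on synthetic samples (train-on-synthetic, test-on-real) or on the real training split (train-on-real, test-on-real), and scoring both on the same held-out real test set. A minimal standalone sketch of that protocol is given below; it assumes `X_train`, `y_train`, `X_test`, `y_test`, and `categories` are the discretised splits used earlier in the notebook, and that `syn_X`, `syn_y` are the one-hot-encoded synthetic samples returned by `generate(..., ohe=True)`.

    from sklearn.preprocessing import OneHotEncoder
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import accuracy_score

    def tstr_trtr_scores(X_train, y_train, X_test, y_test, syn_X, syn_y, categories):
        """Compare train-on-synthetic vs. train-on-real accuracy on the same real test set."""
        # Fit the encoder once on the real training split; the fixed `categories`
        # argument keeps the column layout identical across train, test and synthetic data.
        ohe = OneHotEncoder(categories=categories, sparse=False)
        X_train_oh = ohe.fit_transform(X_train)
        X_test_oh = ohe.transform(X_test)

        # TSTR: train on (already one-hot) synthetic samples, test on real data.
        lr_syn = LogisticRegression(max_iter=1000).fit(syn_X, syn_y)
        tstr = accuracy_score(y_test, lr_syn.predict(X_test_oh))

        # TRTR: train on real data, test on real data (reference score).
        lr_real = LogisticRegression(max_iter=1000).fit(X_train_oh, y_train)
        trtr = accuracy_score(y_test, lr_real.predict(X_test_oh))
        return tstr, trtr

Because `categories` is passed explicitly, the notebook's repeated `ohe.fit_transform(X_test)` calls yield the same encoding as `transform`, but fitting once and transforming afterwards is the safer idiom; raising `max_iter` (1000 here is an assumed value) would also address the `ConvergenceWarning` visible in the logged output.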