From a3641bf49c3b49d8b4e66a64c29d4ae9fe2d396f Mon Sep 17 00:00:00 2001 From: PredelinaAsya <54816157+PredelinaAsya@users.noreply.github.com> Date: Mon, 28 Jun 2021 22:13:06 +0500 Subject: [PATCH 1/3] Add notebook with testing results --- test/Final_test_results.ipynb | 868 ++++++++++++++++++++++++++++++++++ 1 file changed, 868 insertions(+) create mode 100644 test/Final_test_results.ipynb diff --git a/test/Final_test_results.ipynb b/test/Final_test_results.ipynb new file mode 100644 index 0000000..1cfb59a --- /dev/null +++ b/test/Final_test_results.ipynb @@ -0,0 +1,868 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Final test results.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 202 + }, + "id": "pNg202vKNgSe", + "outputId": "eb55ac65-b700-4ce0-d060-e02117393ec5" + }, + "source": [ + "import pandas as pd\n", + "\n", + "data = pd.read_csv('004_human_set_3000.csv')\n", + "data.head()" + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
worddefs
0абажурверхняя часть лампы
1абажурчасть лампы
2абонементэто карточка, которая позволяет тебе ходить в ...
3абрикосмаленький оранжевый фрукт
4абрикосфрукт
\n", + "
" + ], + "text/plain": [ + " word defs\n", + "0 абажур верхняя часть лампы\n", + "1 абажур часть лампы\n", + "2 абонемент это карточка, которая позволяет тебе ходить в ...\n", + "3 абрикос маленький оранжевый фрукт\n", + "4 абрикос фрукт" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 3 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "BBx8K0ImTiHZ", + "outputId": "1d310947-9504-45ac-ce44-df687a73eed6" + }, + "source": [ + "!pip install -U spacy\n", + "!python -m spacy download ru_core_news_lg " + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Requirement already up-to-date: spacy in /usr/local/lib/python3.7/dist-packages (3.0.6)\n", + "Requirement already satisfied, skipping upgrade: setuptools in /usr/local/lib/python3.7/dist-packages (from spacy) (57.0.0)\n", + "Requirement already satisfied, skipping upgrade: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (1.0.5)\n", + "Requirement already satisfied, skipping upgrade: numpy>=1.15.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (1.19.5)\n", + "Requirement already satisfied, skipping upgrade: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (20.9)\n", + "Requirement already satisfied, skipping upgrade: catalogue<2.1.0,>=2.0.3 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.0.4)\n", + "Requirement already satisfied, skipping upgrade: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (4.41.1)\n", + "Requirement already satisfied, skipping upgrade: blis<0.8.0,>=0.4.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (0.4.1)\n", + "Requirement already satisfied, skipping upgrade: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy) (3.0.5)\n", + "Requirement already satisfied, skipping upgrade: srsly<3.0.0,>=2.4.1 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.4.1)\n", + "Requirement already satisfied, skipping upgrade: pydantic<1.8.0,>=1.7.1 in /usr/local/lib/python3.7/dist-packages (from spacy) (1.7.4)\n", + "Requirement already satisfied, skipping upgrade: spacy-legacy<3.1.0,>=3.0.4 in /usr/local/lib/python3.7/dist-packages (from spacy) (3.0.6)\n", + "Requirement already satisfied, skipping upgrade: typing-extensions<4.0.0.0,>=3.7.4; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from spacy) (3.7.4.3)\n", + "Requirement already satisfied, skipping upgrade: wasabi<1.1.0,>=0.8.1 in /usr/local/lib/python3.7/dist-packages (from spacy) (0.8.2)\n", + "Requirement already satisfied, skipping upgrade: pathy>=0.3.5 in /usr/local/lib/python3.7/dist-packages (from spacy) (0.6.0)\n", + "Requirement already satisfied, skipping upgrade: jinja2 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.11.3)\n", + "Requirement already satisfied, skipping upgrade: typer<0.4.0,>=0.3.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (0.3.2)\n", + "Requirement already satisfied, skipping upgrade: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.0.5)\n", + "Requirement already satisfied, skipping upgrade: thinc<8.1.0,>=8.0.3 in /usr/local/lib/python3.7/dist-packages (from spacy) (8.0.6)\n", + "Requirement already satisfied, skipping upgrade: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.23.0)\n", + "Requirement already satisfied, skipping upgrade: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->spacy) (2.4.7)\n", + "Requirement already satisfied, skipping upgrade: zipp>=0.5; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from catalogue<2.1.0,>=2.0.3->spacy) (3.4.1)\n", + "Requirement already satisfied, skipping upgrade: smart-open<6.0.0,>=5.0.0 in /usr/local/lib/python3.7/dist-packages (from pathy>=0.3.5->spacy) (5.1.0)\n", + "Requirement already satisfied, skipping upgrade: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from jinja2->spacy) (2.0.1)\n", + "Requirement already satisfied, skipping upgrade: click<7.2.0,>=7.1.1 in /usr/local/lib/python3.7/dist-packages (from typer<0.4.0,>=0.3.0->spacy) (7.1.2)\n", + "Requirement already satisfied, skipping upgrade: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (3.0.4)\n", + "Requirement already satisfied, skipping upgrade: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (1.24.3)\n", + "Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (2021.5.30)\n", + "Requirement already satisfied, skipping upgrade: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (2.10)\n", + "2021-06-28 14:51:32.756091: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0\n", + "Requirement already satisfied: ru-core-news-lg==3.0.0 from https://github.com/explosion/spacy-models/releases/download/ru_core_news_lg-3.0.0/ru_core_news_lg-3.0.0-py3-none-any.whl#egg=ru_core_news_lg==3.0.0 in /usr/local/lib/python3.7/dist-packages (3.0.0)\n", + "Requirement already satisfied: spacy<3.1.0,>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from ru-core-news-lg==3.0.0) (3.0.6)\n", + "Requirement already satisfied: pymorphy2>=0.9 in /usr/local/lib/python3.7/dist-packages (from ru-core-news-lg==3.0.0) (0.9.1)\n", + "Requirement already satisfied: thinc<8.1.0,>=8.0.3 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (8.0.6)\n", + "Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.4 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (3.0.6)\n", + "Requirement already satisfied: wasabi<1.1.0,>=0.8.1 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (0.8.2)\n", + "Requirement already satisfied: numpy>=1.15.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (1.19.5)\n", + "Requirement already satisfied: catalogue<2.1.0,>=2.0.3 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (2.0.4)\n", + "Requirement already satisfied: typing-extensions<4.0.0.0,>=3.7.4; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (3.7.4.3)\n", + "Requirement already satisfied: pydantic<1.8.0,>=1.7.1 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (1.7.4)\n", + "Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (4.41.1)\n", + "Requirement already satisfied: srsly<3.0.0,>=2.4.1 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (2.4.1)\n", + "Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (2.0.5)\n", + "Requirement already satisfied: typer<0.4.0,>=0.3.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (0.3.2)\n", + "Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (2.23.0)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (20.9)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (2.11.3)\n", + "Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (3.0.5)\n", + "Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (1.0.5)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (57.0.0)\n", + "Requirement already satisfied: blis<0.8.0,>=0.4.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (0.4.1)\n", + "Requirement already satisfied: pathy>=0.3.5 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (0.6.0)\n", + "Requirement already satisfied: dawg-python>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from pymorphy2>=0.9->ru-core-news-lg==3.0.0) (0.7.2)\n", + "Requirement already satisfied: docopt>=0.6 in /usr/local/lib/python3.7/dist-packages (from pymorphy2>=0.9->ru-core-news-lg==3.0.0) (0.6.2)\n", + "Requirement already satisfied: pymorphy2-dicts-ru<3.0,>=2.4 in /usr/local/lib/python3.7/dist-packages (from pymorphy2>=0.9->ru-core-news-lg==3.0.0) (2.4.417127.4579844)\n", + "Requirement already satisfied: zipp>=0.5; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from catalogue<2.1.0,>=2.0.3->spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (3.4.1)\n", + "Requirement already satisfied: click<7.2.0,>=7.1.1 in /usr/local/lib/python3.7/dist-packages (from typer<0.4.0,>=0.3.0->spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (7.1.2)\n", + "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (3.0.4)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (2021.5.30)\n", + "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (2.10)\n", + "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (1.24.3)\n", + "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (2.4.7)\n", + "Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from jinja2->spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (2.0.1)\n", + "Requirement already satisfied: smart-open<6.0.0,>=5.0.0 in /usr/local/lib/python3.7/dist-packages (from pathy>=0.3.5->spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (5.1.0)\n", + "\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n", + "You can now load the package via spacy.load('ru_core_news_lg')\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "8tKITnFlJ0q2" + }, + "source": [ + "import tensorflow.keras\n", + "import spacy\n", + "\n", + "classifier = tensorflow.keras.models.load_model('best_model_lstm.h5')\n", + "nlp = spacy.load('ru_core_news_lg')" + ], + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "g5OS53LQMpQd" + }, + "source": [ + "import numpy as np\n", + "\n", + "\n", + "numb = {\"NOUN\": 1, \"ADJ\": 2, \"ADP\": 3, \"VERB\": 4,\n", + " \"CCONJ\": 5, \"PRON\": 6, \"ADV\": 7, \"DET\": 8,\n", + " \"NUM\": 9, \"PROPN\": 1, \"SCONJ\": 5, \"X\": 10,\n", + " \"AUX\": 10, \"PUNCT\": 11}\n", + "\n", + "max_code_len = 52\n", + "\n", + "\n", + "def encode_def_to_classifier_input(sentence, nlp_model):\n", + " doc = nlp_model(sentence)\n", + " curr_arr = []\n", + "\n", + " for i in range(len(doc)):\n", + " word = doc[i]\n", + " pos_tag = word.pos_\n", + " if pos_tag not in numb:\n", + " curr_arr.append(numb[\"X\"])\n", + " continue\n", + " curr_arr.append(numb[pos_tag])\n", + "\n", + " null_prefix = [0] * (max_code_len - len(curr_arr))\n", + "\n", + " return np.array([null_prefix + curr_arr])" + ], + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9-YbG60TReYB", + "outputId": "26459c61-5fc6-417f-b4fa-190b14620a40" + }, + "source": [ + "question = \"мать, дочь, бабушка\"\n", + "encode_def = encode_def_to_classifier_input(question, nlp)\n", + "type_arr = np.argmax(classifier.predict_on_batch(encode_def))\n", + "print(type_arr)" + ], + "execution_count": 7, + "outputs": [ + { + "output_type": "stream", + "text": [ + "1\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2Y47hhctdgbD", + "outputId": "f61894d7-51c0-47b5-b0dd-8503d4035ea7" + }, + "source": [ + "import nltk\n", + "\n", + "nltk.download('stopwords')\n", + "nltk.download('punkt')" + ], + "execution_count": 8, + "outputs": [ + { + "output_type": "stream", + "text": [ + "[nltk_data] Downloading package stopwords to /root/nltk_data...\n", + "[nltk_data] Package stopwords is already up-to-date!\n", + "[nltk_data] Downloading package punkt to /root/nltk_data...\n", + "[nltk_data] Package punkt is already up-to-date!\n" + ], + "name": "stdout" + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "True" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 8 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "nShFZe6DUzfH" + }, + "source": [ + "import re\n", + "\n", + "from nltk.tokenize import word_tokenize\n", + "from nltk.corpus import stopwords\n", + "\n", + "\n", + "class SumWords:\n", + " '''\n", + " Searching similar words for sentence words sum\n", + " '''\n", + "\n", + " def __init__(self, a_model, a_tops=10):\n", + " self.tops = a_tops\n", + " self.r = re.compile(\"[а-яА-Я-]+\")\n", + " self.stops = stopwords.words(\"russian\")\n", + " self.model = a_model\n", + " self.vocab = self.model.wv.vocab\n", + "\n", + " def is_in_vocab(self, word):\n", + " return word in self.vocab\n", + "\n", + " def get_words(self, sentence, prefix):\n", + " '''\n", + " Returns a list of similar words with fixed prefix.\n", + " List is ranged by similarity level.\n", + " '''\n", + "\n", + " # prepare the list of words from the sentence\n", + " words = [word.lower() for word in word_tokenize(sentence) if word.isalpha()]\n", + " words = [w for w in filter(self.r.match, words)]\n", + " words = [word for word in words if word not in self.stops]\n", + "\n", + " text = ' '.join(words)\n", + " doc = nlp(text)\n", + " words = [word.lemma_ for word in doc]\n", + "\n", + " words = [word for word in words if word in self.vocab]\n", + " if words != []:\n", + " sum_similar = self.model.most_similar(positive=words, topn=self.tops)\n", + " res = [i[0] for i in sum_similar if prefix == i[0][:len(prefix)]]\n", + " else:\n", + " return []\n", + " \n", + " return res" + ], + "execution_count": 9, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "c7ct8T_kpyLv" + }, + "source": [ + "from gensim.models import Word2Vec\n", + "\n", + "word2vec_mod = Word2Vec.load('word2vec.model')\n", + "sum_words = SumWords(word2vec_mod, 20)" + ], + "execution_count": 10, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "n2FNSKCLTYsp", + "outputId": "4c6124ea-b2d4-4077-af4f-c73044d70efd" + }, + "source": [ + "question = \"верхняя часть лампы\"\n", + "lst = sum_words.get_words(question, 'а')\n", + "print(lst)" + ], + "execution_count": 11, + "outputs": [ + { + "output_type": "stream", + "text": [ + "['абажур', 'антаблемент']\n" + ], + "name": "stdout" + }, + { + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:39: DeprecationWarning: Call to deprecated `most_similar` (Method will be removed in 4.0.0, use self.wv.most_similar() instead).\n" + ], + "name": "stderr" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fUj7SCUrPM-s", + "outputId": "82cb9aec-aa4a-4dd9-b15d-268d4b19fdda" + }, + "source": [ + "def get_score_0(prefix_len):\n", + " n_test, n_success = 0, 0\n", + "\n", + " for ind in data.index:\n", + " n_test += 1\n", + "\n", + " word, curr_def = data['word'][ind], data['defs'][ind]\n", + "\n", + " prefix = word[:min(prefix_len, len(word))]\n", + " ans_list = sum_words.get_words(curr_def, prefix)\n", + " if len(ans_list):\n", + " answer = ans_list[0]\n", + " else:\n", + " answer = ''\n", + "\n", + " if answer == word:\n", + " n_success += 1\n", + "\n", + " return n_success / n_test * 100\n", + "\n", + "\n", + "results = {}\n", + "for i in range(0, 7):\n", + " score = get_score_0(i)\n", + " results[i] = score\n", + "\n", + "print(results)" + ], + "execution_count": 12, + "outputs": [ + { + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:39: DeprecationWarning: Call to deprecated `most_similar` (Method will be removed in 4.0.0, use self.wv.most_similar() instead).\n" + ], + "name": "stderr" + }, + { + "output_type": "stream", + "text": [ + "{0: 17.478411053540587, 1: 31.537132987910187, 2: 35.682210708117445, 3: 37.167530224525045, 4: 37.54749568221071, 5: 37.82383419689119, 6: 38.20379965457686}\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "cueWUf0hQvG1" + }, + "source": [ + "import gensim\n", + "from pymystem3 import Mystem\n", + "import re\n", + "\n", + "\n", + "class Word2vec:\n", + " def get_word_vector(self, word):\n", + " raise Exception('Not implemented')\n", + "\n", + "\n", + "class FastText(Word2vec):\n", + " def __init__(self, path):\n", + " self.model = gensim.models.Word2Vec.load(path)\n", + " \n", + " def get_word_vector(self, word):\n", + " return self.model[word] \n", + "\n", + "\n", + "class Text2Lemms:\n", + " def __init__(self):\n", + " self.mystem = Mystem()\n", + " \n", + " def get_lemms(self, text, tag=None):\n", + " list_lemm = []\n", + " for lemma in self.mystem.analyze(text):\n", + " if 'analysis' in lemma and len(lemma['analysis']):\n", + " analysis = lemma['analysis'][0]\n", + " if analysis.get('qual', None) == 'bastard':\n", + " continue\n", + " pos_tag = re.match('[A-Z]+', analysis['gr']).group(0)\n", + " if tag and pos_tag==tag:\n", + " list_lemm.append(analysis['lex'])\n", + " elif not tag:\n", + " list_lemm.append({'lex': analysis['lex'], 'pos': pos_tag})\n", + " return list_lemm\n", + "\n", + "\n", + "class WordTrie:\n", + " '''\n", + " How use \n", + "\n", + " wt = WordTrie(FastText())\n", + " wt.build_dict(['word', 'world', 'cat', 'cats'])\n", + " for word, vector in wt.search_by_prefix('cats'): \n", + " print(word)\n", + " '''\n", + " def __init__(self, word2vec:Word2vec):\n", + " self.root = _Node('*')\n", + " self.get_vector = word2vec.get_word_vector\n", + "\n", + " def add(self, word):\n", + " tmp_node = self.root\n", + "\n", + " for char in word:\n", + " child = tmp_node.children.get(char, None)\n", + " if child is None:\n", + " child = _Node(char)\n", + " tmp_node.children[char] = child\n", + " child.parent = tmp_node\n", + " tmp_node = child\n", + "\n", + " tmp_node.value = self.get_vector(word)\n", + "\n", + " def build_dict(self, words):\n", + " for word in words:\n", + " self.add(word)\n", + " return self\n", + "\n", + " def search_by_prefix(self, prefix):\n", + " tmp_node = self.root\n", + "\n", + " for char in prefix:\n", + " tmp_node = tmp_node.children.get(char, None)\n", + " if tmp_node is None:\n", + " return\n", + " yield from tmp_node.get_childs()\n", + "\n", + "\n", + "#####################\n", + "##### for WordTrie\n", + "#####################\n", + "\n", + "\n", + "class _Node:\n", + " def __init__(self, char: str):\n", + " self.char = char\n", + " self.children = {}\n", + " self.value = None\n", + " self.parent = None\n", + "\n", + " def _get_prefix(self):\n", + " tmp_node = self\n", + " prefix = ''\n", + " while tmp_node.parent:\n", + " prefix = tmp_node.char + prefix\n", + " tmp_node = tmp_node.parent\n", + "\n", + " return prefix\n", + "\n", + " def get_childs(self):\n", + " stack = [(self, self._get_prefix())]\n", + "\n", + " while stack:\n", + " tmp_node, prefix = stack.pop(0)\n", + " if tmp_node.value is not None:\n", + " yield prefix, tmp_node.value\n", + " \n", + " for char, child in tmp_node.children.items():\n", + " stack.append((child, prefix+char))\n" + ], + "execution_count": 13, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "0xAyngNvQVzw" + }, + "source": [ + "import pickle\n", + "\n", + "with open('tree.pickle', 'rb') as f:\n", + " trie = pickle.load(f)" + ], + "execution_count": 14, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9Y9Dr4FKVeX3", + "outputId": "fbe2727d-416d-4b11-dcb5-f07601683bd1" + }, + "source": [ + "!pip install wiki-ru-wordnet" + ], + "execution_count": 15, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Requirement already satisfied: wiki-ru-wordnet in /usr/local/lib/python3.7/dist-packages (1.0.3)\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "lvwI57poVRsn" + }, + "source": [ + "from wiki_ru_wordnet import WikiWordnet" + ], + "execution_count": 16, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "pvW8Z4xPVzsi" + }, + "source": [ + "import pymorphy2\n", + "\n", + "\n", + "morph = pymorphy2.MorphAnalyzer()\n", + "\n", + "\n", + "def search_simple_bigramm(lemma_list):\n", + " if not lemma_list:\n", + " return\n", + "\n", + " for w1, w2 in zip(lemma_list, lemma_list[1:]):\n", + " if w1['pos']=='A' and w2['pos']=='S':\n", + " # parsing word - receiving list of possible values\n", + " w1_vars = morph.parse(w1['lex'])\n", + " # choosing elements with compatible POS\n", + " w1_vars = [i for i in w1_vars if i.tag.POS == 'ADJF']\n", + "\n", + " word2 = w2['lex']\n", + " w2_vars = morph.parse(word2)\n", + " w2_vars = [i for i in w2_vars if i.tag.POS == 'NOUN']\n", + "\n", + " if len(w1_vars) == 0 or len(w2_vars) == 0:\n", + " return\n", + " gender = w2_vars[0].tag.gender\n", + "\n", + " try:\n", + " word1 = w1_vars[0].inflect({gender}).word\n", + " except ValueError:\n", + " word1 = w1_vars[0].word\n", + " return word1, word2\n" + ], + "execution_count": 17, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "0hL5HIhkUloX" + }, + "source": [ + "class HypWords:\n", + " # hyponym and hypernym words\n", + " def __init__(self, prefix_trie=None):\n", + " self.wikiwordnet = WikiWordnet()\n", + " #self.prefix_trie = get_prefix_trie() # перенести потом в бота, инициализировать перед запуском, как и остальные модели\n", + " self.prefix_trie = prefix_trie\n", + " self.text2LemmsModel = Text2Lemms()\n", + "\n", + " def _get_hyp_with_prefix(self, word, words_with_prefix):\n", + " hyp_w = get_hyponym_and_hypernym(self.wikiwordnet, word)\n", + " answer = hyp_w & words_with_prefix\n", + " if answer:\n", + " return list(answer)\n", + " return []\n", + "\n", + " def is_in_vocab(self, word):\n", + " sets = self.wikiwordnet.get_synsets(word)\n", + " return len(sets) != 0\n", + "\n", + " def get_words(self, sentence, prefix):\n", + " list_lex = self.text2LemmsModel.get_lemms(sentence)\n", + " bigramm_w = search_simple_bigramm(list_lex)\n", + " words_with_prefix = set(w[0] for w in self.prefix_trie.search_by_prefix(prefix))\n", + " word = None\n", + "\n", + " if bigramm_w:\n", + " bigramm = ' '.join(bigramm_w)\n", + " ans = self._get_hyp_with_prefix(bigramm, words_with_prefix)\n", + " if ans:\n", + " return ans\n", + " word = bigramm_w[1]\n", + " else:\n", + " for w in list_lex:\n", + " if w['pos'] == 'S':\n", + " word = w['lex']\n", + " break\n", + "\n", + " return self._get_hyp_with_prefix(word, words_with_prefix)\n", + "\n", + "\n", + "def get_hyponym_and_hypernym(wikiwordnet, word):\n", + " synsets = wikiwordnet.get_synsets(word)\n", + " set_words = set()\n", + "\n", + " if not synsets:\n", + " return set_words\n", + "\n", + " synset1 = synsets[0]\n", + "\n", + " for hyponym in wikiwordnet.get_hyponyms(synset1):\n", + " set_words |= { w.lemma() for w in hyponym.get_words()}\n", + "\n", + " for hypernym in wikiwordnet.get_hypernyms(synset1):\n", + " set_words |= { w.lemma() for w in hypernym.get_words()}\n", + "\n", + " return set_words\n" + ], + "execution_count": 18, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "Tbx59RbOV5Gs" + }, + "source": [ + "hyp_words = HypWords(prefix_trie=trie)" + ], + "execution_count": 19, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "wi7COOMY9GPI", + "outputId": "a87723b0-cfed-4d5f-e55d-84547466a1bf" + }, + "source": [ + "question = \"лампа\"\n", + "lst = hyp_words.get_words(question, 'л')\n", + "print(lst)" + ], + "execution_count": 20, + "outputs": [ + { + "output_type": "stream", + "text": [ + "['лампада']\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "lP0Amz0bWoR2" + }, + "source": [ + "LIST_MODELS = [sum_words, hyp_words]" + ], + "execution_count": 21, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "3VNMdun5IpwY" + }, + "source": [ + "def get_score(prefix_len):\n", + " n_test, n_success = 0, 0\n", + "\n", + " for ind in data.index:\n", + " n_test += 1\n", + "\n", + " word, curr_def = data['word'][ind], data['defs'][ind]\n", + "\n", + " encode_def = encode_def_to_classifier_input(curr_def, nlp)\n", + " def_type = np.argmax(classifier.predict_on_batch(encode_def))\n", + "\n", + " list_words = []\n", + " prefix = word[:min(prefix_len, len(word))]\n", + " for model in LIST_MODELS:\n", + " ans_words = model.get_words(question, prefix)\n", + " if len(ans_words):\n", + " list_words.append(ans_words[0])\n", + "\n", + " index = 1 if def_type == 1 and len(list_words) == 2 else 0\n", + " if index < len(list_words):\n", + " answer = list_words[index]\n", + " elif len(list_words):\n", + " answer = list_words[0]\n", + " else:\n", + " answer = \"\"\n", + "\n", + " if answer == word:\n", + " n_success += 1\n", + "\n", + " return n_success / n_test * 100" + ], + "execution_count": 22, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "pq5glqGRYvIF", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "974b6fb4-6429-4743-f27d-ac1c6d86b1c1" + }, + "source": [ + "results = {}\n", + "for i in range(0, 7):\n", + " score = get_score(i)\n", + " results[i] = score\n", + "\n", + "print(results)" + ], + "execution_count": 23, + "outputs": [ + { + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:39: DeprecationWarning: Call to deprecated `most_similar` (Method will be removed in 4.0.0, use self.wv.most_similar() instead).\n" + ], + "name": "stderr" + }, + { + "output_type": "stream", + "text": [ + "{0: 0.03454231433506045, 1: 0.0690846286701209, 2: 0.17271157167530224, 3: 0.17271157167530224, 4: 0.17271157167530224, 5: 0.17271157167530224, 6: 0.17271157167530224}\n" + ], + "name": "stdout" + } + ] + } + ] +} \ No newline at end of file From c00a8d9b9b6edc24779a9400f3712c65750c4463 Mon Sep 17 00:00:00 2001 From: PredelinaAsya <54816157+PredelinaAsya@users.noreply.github.com> Date: Mon, 28 Jun 2021 22:24:10 +0500 Subject: [PATCH 2/3] Updated the file with the results from 18.12.20 --- test/Tests_results.ipynb | 67 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 66 insertions(+), 1 deletion(-) diff --git a/test/Tests_results.ipynb b/test/Tests_results.ipynb index 6654a2e..c28f279 100644 --- a/test/Tests_results.ipynb +++ b/test/Tests_results.ipynb @@ -15,6 +15,71 @@ "*Новые результаты предлагается добавлять на верх страницы, чтобы избежать проматывания при их поиске*" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 18.12.2020" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Тестовый датасет\n", + "\n", + "data/test_sets/004_human_set.csv - несловарный датасет, содержит около 3000 определений. Датасет собран осенью 2020 при работе бота на платформе Heroku." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Модели поиска ответа:\n", + "\n", + "* hyp_model - /src/models/hyp_model.py. На базе семантической сети wikiwordnet.\n", + " * лемматизация слов\n", + " * поиск базовой биграммы фразы\n", + " * поиск гипонимов и гиперонимов для слов базовой биграммы\n", + "\n", + "\n", + "* sum_model - /src/models/sum_model.py. Word2vec model - сжатая версия fasttext модели от RusVectores [ft_freqprune_100K_20K_pq_300.bin](https://github.com/avidale/compress-fasttext/releases/tag/v0.0.1).\n", + " * лемматизация слов\n", + " * сложение векторов слов\n", + " * поиск ближайшего среди слов с соответствующим префиксом\n", + " \n", + " \n", + "* sum_plus_hyp_model - объединение двух упомянутых моделей. На выход передаётся до двух слов - результаты от каждой модели. Таким образом мы проверяем, могут ли модели дополнять друг друга - распознают ли они разные определения.\n", + "\n", + "\n", + "* ens_sum_hyp - объединение двух упомянутых моделей. При обработке определения решается, ответ какой из моделей будет передан на выход. Выбор по числу существительных и глаголов в определении." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Метрика\n", + "\n", + "Проверяется точное соответствие ответа загаданному слову (в случае модели sum_plus_hyp_model - наличие загаданного слова в выходном списке).\n", + "\n", + "Финальный скор - процент слов, для которых проверка прошла успешно." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Таблица результатов\n", + "\n", + "Модель | it/s | Score, %\n", + "------------- | ----------- | -----\n", + "hyp_model | 86.63 | 7,9\n", + "sum_model | 1.72 | 16.65\n", + "sum_plus_hyp_model | 1.79 | 22\n", + "ens_sum_hyp | 1.38 | 17.6\n" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -176,7 +241,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.2" + "version": "3.8.5" } }, "nbformat": 4, From 1ae6bd5b6592b4b007918528e99da43722ef2ab5 Mon Sep 17 00:00:00 2001 From: PredelinaAsya <54816157+PredelinaAsya@users.noreply.github.com> Date: Mon, 28 Jun 2021 22:51:57 +0500 Subject: [PATCH 3/3] Test results 28.06 --- test/Tests_results.ipynb | 56 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/test/Tests_results.ipynb b/test/Tests_results.ipynb index c28f279..cd5f3cc 100644 --- a/test/Tests_results.ipynb +++ b/test/Tests_results.ipynb @@ -15,6 +15,60 @@ "*Новые результаты предлагается добавлять на верх страницы, чтобы избежать проматывания при их поиске*" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 28.06.2021" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Тестовый датасет\n", + "\n", + "data/test_sets/004_human_set.csv" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Модели поиска ответа:\n", + "\n", + "* sum_model - /src/models/sum_model.py. Word2vec model - обучена на смешанном датасете определений (формальных с википедии и пользовательских)\n", + " * лемматизация слов\n", + " * сложение векторов слов\n", + " * поиск ближайшего среди слов с соответствующим префиксом\n", + " \n", + " \n", + "* ens_sum_hyp - объединение двух моделей суммирования и гипонимов/гиперонимов (при построении префиксного дерева был использован обученный на смешанном датасете Word2vec). При обработке определения решается, ответ какой из моделей будет передан на выход. Выбор на основе классификатора определений (https://github.com/Alioth6/contact_game_bot/blob/master/src/models/Classification%20with%20LSTM.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Метрика\n", + "\n", + "Проверяется точное соответствие ответа загаданному слову при нескольких первых известных буквах в загаданном слове (от 0 до 4).\n", + "\n", + "Финальный скор - процент слов, для которых проверка прошла успешно." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Таблица результатов\n", + "\n", + "Модель | Знаем 0 букв; score, % | 1 буква; score, % | 2 буквы; score, % | 3 буквы; score, % | 4 буквы score, %\n", + "------------- | ----------- | ----------- | ----------- | ----------- | -----\n", + "sum_model | 17.48 | 31.54 | 35.68 | 37.17 | 37.55\n", + "ens_sum_hyp | 0.03 | 0.07 | 0.17 | 0.17 | 0.17" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -241,7 +295,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.7.3" } }, "nbformat": 4,