diff --git a/test/Final_test_results.ipynb b/test/Final_test_results.ipynb
new file mode 100644
index 0000000..1cfb59a
--- /dev/null
+++ b/test/Final_test_results.ipynb
@@ -0,0 +1,868 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "Final test results.ipynb",
+ "provenance": [],
+ "collapsed_sections": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 202
+ },
+ "id": "pNg202vKNgSe",
+ "outputId": "eb55ac65-b700-4ce0-d060-e02117393ec5"
+ },
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "data = pd.read_csv('004_human_set_3000.csv')\n",
+ "data.head()"
+ ],
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " word | \n",
+ " defs | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " абажур | \n",
+ " верхняя часть лампы | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " абажур | \n",
+ " часть лампы | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " абонемент | \n",
+ " это карточка, которая позволяет тебе ходить в ... | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " абрикос | \n",
+ " маленький оранжевый фрукт | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " абрикос | \n",
+ " фрукт | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " word defs\n",
+ "0 абажур верхняя часть лампы\n",
+ "1 абажур часть лампы\n",
+ "2 абонемент это карточка, которая позволяет тебе ходить в ...\n",
+ "3 абрикос маленький оранжевый фрукт\n",
+ "4 абрикос фрукт"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 3
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "BBx8K0ImTiHZ",
+ "outputId": "1d310947-9504-45ac-ce44-df687a73eed6"
+ },
+ "source": [
+ "!pip install -U spacy\n",
+ "!python -m spacy download ru_core_news_lg "
+ ],
+ "execution_count": 4,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Requirement already up-to-date: spacy in /usr/local/lib/python3.7/dist-packages (3.0.6)\n",
+ "Requirement already satisfied, skipping upgrade: setuptools in /usr/local/lib/python3.7/dist-packages (from spacy) (57.0.0)\n",
+ "Requirement already satisfied, skipping upgrade: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (1.0.5)\n",
+ "Requirement already satisfied, skipping upgrade: numpy>=1.15.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (1.19.5)\n",
+ "Requirement already satisfied, skipping upgrade: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (20.9)\n",
+ "Requirement already satisfied, skipping upgrade: catalogue<2.1.0,>=2.0.3 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.0.4)\n",
+ "Requirement already satisfied, skipping upgrade: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (4.41.1)\n",
+ "Requirement already satisfied, skipping upgrade: blis<0.8.0,>=0.4.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (0.4.1)\n",
+ "Requirement already satisfied, skipping upgrade: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy) (3.0.5)\n",
+ "Requirement already satisfied, skipping upgrade: srsly<3.0.0,>=2.4.1 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.4.1)\n",
+ "Requirement already satisfied, skipping upgrade: pydantic<1.8.0,>=1.7.1 in /usr/local/lib/python3.7/dist-packages (from spacy) (1.7.4)\n",
+ "Requirement already satisfied, skipping upgrade: spacy-legacy<3.1.0,>=3.0.4 in /usr/local/lib/python3.7/dist-packages (from spacy) (3.0.6)\n",
+ "Requirement already satisfied, skipping upgrade: typing-extensions<4.0.0.0,>=3.7.4; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from spacy) (3.7.4.3)\n",
+ "Requirement already satisfied, skipping upgrade: wasabi<1.1.0,>=0.8.1 in /usr/local/lib/python3.7/dist-packages (from spacy) (0.8.2)\n",
+ "Requirement already satisfied, skipping upgrade: pathy>=0.3.5 in /usr/local/lib/python3.7/dist-packages (from spacy) (0.6.0)\n",
+ "Requirement already satisfied, skipping upgrade: jinja2 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.11.3)\n",
+ "Requirement already satisfied, skipping upgrade: typer<0.4.0,>=0.3.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (0.3.2)\n",
+ "Requirement already satisfied, skipping upgrade: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.0.5)\n",
+ "Requirement already satisfied, skipping upgrade: thinc<8.1.0,>=8.0.3 in /usr/local/lib/python3.7/dist-packages (from spacy) (8.0.6)\n",
+ "Requirement already satisfied, skipping upgrade: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.23.0)\n",
+ "Requirement already satisfied, skipping upgrade: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->spacy) (2.4.7)\n",
+ "Requirement already satisfied, skipping upgrade: zipp>=0.5; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from catalogue<2.1.0,>=2.0.3->spacy) (3.4.1)\n",
+ "Requirement already satisfied, skipping upgrade: smart-open<6.0.0,>=5.0.0 in /usr/local/lib/python3.7/dist-packages (from pathy>=0.3.5->spacy) (5.1.0)\n",
+ "Requirement already satisfied, skipping upgrade: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from jinja2->spacy) (2.0.1)\n",
+ "Requirement already satisfied, skipping upgrade: click<7.2.0,>=7.1.1 in /usr/local/lib/python3.7/dist-packages (from typer<0.4.0,>=0.3.0->spacy) (7.1.2)\n",
+ "Requirement already satisfied, skipping upgrade: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (3.0.4)\n",
+ "Requirement already satisfied, skipping upgrade: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (1.24.3)\n",
+ "Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (2021.5.30)\n",
+ "Requirement already satisfied, skipping upgrade: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (2.10)\n",
+ "2021-06-28 14:51:32.756091: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0\n",
+ "Requirement already satisfied: ru-core-news-lg==3.0.0 from https://github.com/explosion/spacy-models/releases/download/ru_core_news_lg-3.0.0/ru_core_news_lg-3.0.0-py3-none-any.whl#egg=ru_core_news_lg==3.0.0 in /usr/local/lib/python3.7/dist-packages (3.0.0)\n",
+ "Requirement already satisfied: spacy<3.1.0,>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from ru-core-news-lg==3.0.0) (3.0.6)\n",
+ "Requirement already satisfied: pymorphy2>=0.9 in /usr/local/lib/python3.7/dist-packages (from ru-core-news-lg==3.0.0) (0.9.1)\n",
+ "Requirement already satisfied: thinc<8.1.0,>=8.0.3 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (8.0.6)\n",
+ "Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.4 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (3.0.6)\n",
+ "Requirement already satisfied: wasabi<1.1.0,>=0.8.1 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (0.8.2)\n",
+ "Requirement already satisfied: numpy>=1.15.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (1.19.5)\n",
+ "Requirement already satisfied: catalogue<2.1.0,>=2.0.3 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (2.0.4)\n",
+ "Requirement already satisfied: typing-extensions<4.0.0.0,>=3.7.4; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (3.7.4.3)\n",
+ "Requirement already satisfied: pydantic<1.8.0,>=1.7.1 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (1.7.4)\n",
+ "Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (4.41.1)\n",
+ "Requirement already satisfied: srsly<3.0.0,>=2.4.1 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (2.4.1)\n",
+ "Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (2.0.5)\n",
+ "Requirement already satisfied: typer<0.4.0,>=0.3.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (0.3.2)\n",
+ "Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (2.23.0)\n",
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (20.9)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (2.11.3)\n",
+ "Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (3.0.5)\n",
+ "Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (1.0.5)\n",
+ "Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (57.0.0)\n",
+ "Requirement already satisfied: blis<0.8.0,>=0.4.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (0.4.1)\n",
+ "Requirement already satisfied: pathy>=0.3.5 in /usr/local/lib/python3.7/dist-packages (from spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (0.6.0)\n",
+ "Requirement already satisfied: dawg-python>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from pymorphy2>=0.9->ru-core-news-lg==3.0.0) (0.7.2)\n",
+ "Requirement already satisfied: docopt>=0.6 in /usr/local/lib/python3.7/dist-packages (from pymorphy2>=0.9->ru-core-news-lg==3.0.0) (0.6.2)\n",
+ "Requirement already satisfied: pymorphy2-dicts-ru<3.0,>=2.4 in /usr/local/lib/python3.7/dist-packages (from pymorphy2>=0.9->ru-core-news-lg==3.0.0) (2.4.417127.4579844)\n",
+ "Requirement already satisfied: zipp>=0.5; python_version < \"3.8\" in /usr/local/lib/python3.7/dist-packages (from catalogue<2.1.0,>=2.0.3->spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (3.4.1)\n",
+ "Requirement already satisfied: click<7.2.0,>=7.1.1 in /usr/local/lib/python3.7/dist-packages (from typer<0.4.0,>=0.3.0->spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (7.1.2)\n",
+ "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (3.0.4)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (2021.5.30)\n",
+ "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (2.10)\n",
+ "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (1.24.3)\n",
+ "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (2.4.7)\n",
+ "Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from jinja2->spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (2.0.1)\n",
+ "Requirement already satisfied: smart-open<6.0.0,>=5.0.0 in /usr/local/lib/python3.7/dist-packages (from pathy>=0.3.5->spacy<3.1.0,>=3.0.0->ru-core-news-lg==3.0.0) (5.1.0)\n",
+ "\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n",
+ "You can now load the package via spacy.load('ru_core_news_lg')\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "8tKITnFlJ0q2"
+ },
+ "source": [
+ "import tensorflow.keras\n",
+ "import spacy\n",
+ "\n",
+ "classifier = tensorflow.keras.models.load_model('best_model_lstm.h5')\n",
+ "nlp = spacy.load('ru_core_news_lg')"
+ ],
+ "execution_count": 5,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "g5OS53LQMpQd"
+ },
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "\n",
+ "numb = {\"NOUN\": 1, \"ADJ\": 2, \"ADP\": 3, \"VERB\": 4,\n",
+ " \"CCONJ\": 5, \"PRON\": 6, \"ADV\": 7, \"DET\": 8,\n",
+ " \"NUM\": 9, \"PROPN\": 1, \"SCONJ\": 5, \"X\": 10,\n",
+ " \"AUX\": 10, \"PUNCT\": 11}\n",
+ "\n",
+ "max_code_len = 52\n",
+ "\n",
+ "\n",
+ "def encode_def_to_classifier_input(sentence, nlp_model):\n",
+ " doc = nlp_model(sentence)\n",
+ " curr_arr = []\n",
+ "\n",
+ " for i in range(len(doc)):\n",
+ " word = doc[i]\n",
+ " pos_tag = word.pos_\n",
+ " if pos_tag not in numb:\n",
+ " curr_arr.append(numb[\"X\"])\n",
+ " continue\n",
+ " curr_arr.append(numb[pos_tag])\n",
+ "\n",
+ " null_prefix = [0] * (max_code_len - len(curr_arr))\n",
+ "\n",
+ " return np.array([null_prefix + curr_arr])"
+ ],
+ "execution_count": 6,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "9-YbG60TReYB",
+ "outputId": "26459c61-5fc6-417f-b4fa-190b14620a40"
+ },
+ "source": [
+ "question = \"мать, дочь, бабушка\"\n",
+ "encode_def = encode_def_to_classifier_input(question, nlp)\n",
+ "type_arr = np.argmax(classifier.predict_on_batch(encode_def))\n",
+ "print(type_arr)"
+ ],
+ "execution_count": 7,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "1\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "2Y47hhctdgbD",
+ "outputId": "f61894d7-51c0-47b5-b0dd-8503d4035ea7"
+ },
+ "source": [
+ "import nltk\n",
+ "\n",
+ "nltk.download('stopwords')\n",
+ "nltk.download('punkt')"
+ ],
+ "execution_count": 8,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "[nltk_data] Downloading package stopwords to /root/nltk_data...\n",
+ "[nltk_data] Package stopwords is already up-to-date!\n",
+ "[nltk_data] Downloading package punkt to /root/nltk_data...\n",
+ "[nltk_data] Package punkt is already up-to-date!\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 8
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "nShFZe6DUzfH"
+ },
+ "source": [
+ "import re\n",
+ "\n",
+ "from nltk.tokenize import word_tokenize\n",
+ "from nltk.corpus import stopwords\n",
+ "\n",
+ "\n",
+ "class SumWords:\n",
+ " '''\n",
+ " Searching similar words for sentence words sum\n",
+ " '''\n",
+ "\n",
+ " def __init__(self, a_model, a_tops=10):\n",
+ " self.tops = a_tops\n",
+ " self.r = re.compile(\"[а-яА-Я-]+\")\n",
+ " self.stops = stopwords.words(\"russian\")\n",
+ " self.model = a_model\n",
+ " self.vocab = self.model.wv.vocab\n",
+ "\n",
+ " def is_in_vocab(self, word):\n",
+ " return word in self.vocab\n",
+ "\n",
+ " def get_words(self, sentence, prefix):\n",
+ " '''\n",
+ " Returns a list of similar words with fixed prefix.\n",
+ " List is ranged by similarity level.\n",
+ " '''\n",
+ "\n",
+ " # prepare the list of words from the sentence\n",
+ " words = [word.lower() for word in word_tokenize(sentence) if word.isalpha()]\n",
+ " words = [w for w in filter(self.r.match, words)]\n",
+ " words = [word for word in words if word not in self.stops]\n",
+ "\n",
+ " text = ' '.join(words)\n",
+ " doc = nlp(text)\n",
+ " words = [word.lemma_ for word in doc]\n",
+ "\n",
+ " words = [word for word in words if word in self.vocab]\n",
+ " if words != []:\n",
+ " sum_similar = self.model.most_similar(positive=words, topn=self.tops)\n",
+ " res = [i[0] for i in sum_similar if prefix == i[0][:len(prefix)]]\n",
+ " else:\n",
+ " return []\n",
+ " \n",
+ " return res"
+ ],
+ "execution_count": 9,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "c7ct8T_kpyLv"
+ },
+ "source": [
+ "from gensim.models import Word2Vec\n",
+ "\n",
+ "word2vec_mod = Word2Vec.load('word2vec.model')\n",
+ "sum_words = SumWords(word2vec_mod, 20)"
+ ],
+ "execution_count": 10,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "n2FNSKCLTYsp",
+ "outputId": "4c6124ea-b2d4-4077-af4f-c73044d70efd"
+ },
+ "source": [
+ "question = \"верхняя часть лампы\"\n",
+ "lst = sum_words.get_words(question, 'а')\n",
+ "print(lst)"
+ ],
+ "execution_count": 11,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "['абажур', 'антаблемент']\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:39: DeprecationWarning: Call to deprecated `most_similar` (Method will be removed in 4.0.0, use self.wv.most_similar() instead).\n"
+ ],
+ "name": "stderr"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "fUj7SCUrPM-s",
+ "outputId": "82cb9aec-aa4a-4dd9-b15d-268d4b19fdda"
+ },
+ "source": [
+ "def get_score_0(prefix_len):\n",
+ " n_test, n_success = 0, 0\n",
+ "\n",
+ " for ind in data.index:\n",
+ " n_test += 1\n",
+ "\n",
+ " word, curr_def = data['word'][ind], data['defs'][ind]\n",
+ "\n",
+ " prefix = word[:min(prefix_len, len(word))]\n",
+ " ans_list = sum_words.get_words(curr_def, prefix)\n",
+ " if len(ans_list):\n",
+ " answer = ans_list[0]\n",
+ " else:\n",
+ " answer = ''\n",
+ "\n",
+ " if answer == word:\n",
+ " n_success += 1\n",
+ "\n",
+ " return n_success / n_test * 100\n",
+ "\n",
+ "\n",
+ "results = {}\n",
+ "for i in range(0, 7):\n",
+ " score = get_score_0(i)\n",
+ " results[i] = score\n",
+ "\n",
+ "print(results)"
+ ],
+ "execution_count": 12,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:39: DeprecationWarning: Call to deprecated `most_similar` (Method will be removed in 4.0.0, use self.wv.most_similar() instead).\n"
+ ],
+ "name": "stderr"
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "{0: 17.478411053540587, 1: 31.537132987910187, 2: 35.682210708117445, 3: 37.167530224525045, 4: 37.54749568221071, 5: 37.82383419689119, 6: 38.20379965457686}\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "cueWUf0hQvG1"
+ },
+ "source": [
+ "import gensim\n",
+ "from pymystem3 import Mystem\n",
+ "import re\n",
+ "\n",
+ "\n",
+ "class Word2vec:\n",
+ " def get_word_vector(self, word):\n",
+ " raise Exception('Not implemented')\n",
+ "\n",
+ "\n",
+ "class FastText(Word2vec):\n",
+ " def __init__(self, path):\n",
+ " self.model = gensim.models.Word2Vec.load(path)\n",
+ " \n",
+ " def get_word_vector(self, word):\n",
+ " return self.model[word] \n",
+ "\n",
+ "\n",
+ "class Text2Lemms:\n",
+ " def __init__(self):\n",
+ " self.mystem = Mystem()\n",
+ " \n",
+ " def get_lemms(self, text, tag=None):\n",
+ " list_lemm = []\n",
+ " for lemma in self.mystem.analyze(text):\n",
+ " if 'analysis' in lemma and len(lemma['analysis']):\n",
+ " analysis = lemma['analysis'][0]\n",
+ " if analysis.get('qual', None) == 'bastard':\n",
+ " continue\n",
+ " pos_tag = re.match('[A-Z]+', analysis['gr']).group(0)\n",
+ " if tag and pos_tag==tag:\n",
+ " list_lemm.append(analysis['lex'])\n",
+ " elif not tag:\n",
+ " list_lemm.append({'lex': analysis['lex'], 'pos': pos_tag})\n",
+ " return list_lemm\n",
+ "\n",
+ "\n",
+ "class WordTrie:\n",
+ " '''\n",
+ " How use \n",
+ "\n",
+ " wt = WordTrie(FastText())\n",
+ " wt.build_dict(['word', 'world', 'cat', 'cats'])\n",
+ " for word, vector in wt.search_by_prefix('cats'): \n",
+ " print(word)\n",
+ " '''\n",
+ " def __init__(self, word2vec:Word2vec):\n",
+ " self.root = _Node('*')\n",
+ " self.get_vector = word2vec.get_word_vector\n",
+ "\n",
+ " def add(self, word):\n",
+ " tmp_node = self.root\n",
+ "\n",
+ " for char in word:\n",
+ " child = tmp_node.children.get(char, None)\n",
+ " if child is None:\n",
+ " child = _Node(char)\n",
+ " tmp_node.children[char] = child\n",
+ " child.parent = tmp_node\n",
+ " tmp_node = child\n",
+ "\n",
+ " tmp_node.value = self.get_vector(word)\n",
+ "\n",
+ " def build_dict(self, words):\n",
+ " for word in words:\n",
+ " self.add(word)\n",
+ " return self\n",
+ "\n",
+ " def search_by_prefix(self, prefix):\n",
+ " tmp_node = self.root\n",
+ "\n",
+ " for char in prefix:\n",
+ " tmp_node = tmp_node.children.get(char, None)\n",
+ " if tmp_node is None:\n",
+ " return\n",
+ " yield from tmp_node.get_childs()\n",
+ "\n",
+ "\n",
+ "#####################\n",
+ "##### for WordTrie\n",
+ "#####################\n",
+ "\n",
+ "\n",
+ "class _Node:\n",
+ " def __init__(self, char: str):\n",
+ " self.char = char\n",
+ " self.children = {}\n",
+ " self.value = None\n",
+ " self.parent = None\n",
+ "\n",
+ " def _get_prefix(self):\n",
+ " tmp_node = self\n",
+ " prefix = ''\n",
+ " while tmp_node.parent:\n",
+ " prefix = tmp_node.char + prefix\n",
+ " tmp_node = tmp_node.parent\n",
+ "\n",
+ " return prefix\n",
+ "\n",
+ " def get_childs(self):\n",
+ " stack = [(self, self._get_prefix())]\n",
+ "\n",
+ " while stack:\n",
+ " tmp_node, prefix = stack.pop(0)\n",
+ " if tmp_node.value is not None:\n",
+ " yield prefix, tmp_node.value\n",
+ " \n",
+ " for char, child in tmp_node.children.items():\n",
+ " stack.append((child, prefix+char))\n"
+ ],
+ "execution_count": 13,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "0xAyngNvQVzw"
+ },
+ "source": [
+ "import pickle\n",
+ "\n",
+ "with open('tree.pickle', 'rb') as f:\n",
+ " trie = pickle.load(f)"
+ ],
+ "execution_count": 14,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "9Y9Dr4FKVeX3",
+ "outputId": "fbe2727d-416d-4b11-dcb5-f07601683bd1"
+ },
+ "source": [
+ "!pip install wiki-ru-wordnet"
+ ],
+ "execution_count": 15,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: wiki-ru-wordnet in /usr/local/lib/python3.7/dist-packages (1.0.3)\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "lvwI57poVRsn"
+ },
+ "source": [
+ "from wiki_ru_wordnet import WikiWordnet"
+ ],
+ "execution_count": 16,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "pvW8Z4xPVzsi"
+ },
+ "source": [
+ "import pymorphy2\n",
+ "\n",
+ "\n",
+ "morph = pymorphy2.MorphAnalyzer()\n",
+ "\n",
+ "\n",
+ "def search_simple_bigramm(lemma_list):\n",
+ " if not lemma_list:\n",
+ " return\n",
+ "\n",
+ " for w1, w2 in zip(lemma_list, lemma_list[1:]):\n",
+ " if w1['pos']=='A' and w2['pos']=='S':\n",
+ " # parsing word - receiving list of possible values\n",
+ " w1_vars = morph.parse(w1['lex'])\n",
+ " # choosing elements with compatible POS\n",
+ " w1_vars = [i for i in w1_vars if i.tag.POS == 'ADJF']\n",
+ "\n",
+ " word2 = w2['lex']\n",
+ " w2_vars = morph.parse(word2)\n",
+ " w2_vars = [i for i in w2_vars if i.tag.POS == 'NOUN']\n",
+ "\n",
+ " if len(w1_vars) == 0 or len(w2_vars) == 0:\n",
+ " return\n",
+ " gender = w2_vars[0].tag.gender\n",
+ "\n",
+ " try:\n",
+ " word1 = w1_vars[0].inflect({gender}).word\n",
+ " except ValueError:\n",
+ " word1 = w1_vars[0].word\n",
+ " return word1, word2\n"
+ ],
+ "execution_count": 17,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "0hL5HIhkUloX"
+ },
+ "source": [
+ "class HypWords:\n",
+ " # hyponym and hypernym words\n",
+ " def __init__(self, prefix_trie=None):\n",
+ " self.wikiwordnet = WikiWordnet()\n",
+ " #self.prefix_trie = get_prefix_trie() # перенести потом в бота, инициализировать перед запуском, как и остальные модели\n",
+ " self.prefix_trie = prefix_trie\n",
+ " self.text2LemmsModel = Text2Lemms()\n",
+ "\n",
+ " def _get_hyp_with_prefix(self, word, words_with_prefix):\n",
+ " hyp_w = get_hyponym_and_hypernym(self.wikiwordnet, word)\n",
+ " answer = hyp_w & words_with_prefix\n",
+ " if answer:\n",
+ " return list(answer)\n",
+ " return []\n",
+ "\n",
+ " def is_in_vocab(self, word):\n",
+ " sets = self.wikiwordnet.get_synsets(word)\n",
+ " return len(sets) != 0\n",
+ "\n",
+ " def get_words(self, sentence, prefix):\n",
+ " list_lex = self.text2LemmsModel.get_lemms(sentence)\n",
+ " bigramm_w = search_simple_bigramm(list_lex)\n",
+ " words_with_prefix = set(w[0] for w in self.prefix_trie.search_by_prefix(prefix))\n",
+ " word = None\n",
+ "\n",
+ " if bigramm_w:\n",
+ " bigramm = ' '.join(bigramm_w)\n",
+ " ans = self._get_hyp_with_prefix(bigramm, words_with_prefix)\n",
+ " if ans:\n",
+ " return ans\n",
+ " word = bigramm_w[1]\n",
+ " else:\n",
+ " for w in list_lex:\n",
+ " if w['pos'] == 'S':\n",
+ " word = w['lex']\n",
+ " break\n",
+ "\n",
+ " return self._get_hyp_with_prefix(word, words_with_prefix)\n",
+ "\n",
+ "\n",
+ "def get_hyponym_and_hypernym(wikiwordnet, word):\n",
+ " synsets = wikiwordnet.get_synsets(word)\n",
+ " set_words = set()\n",
+ "\n",
+ " if not synsets:\n",
+ " return set_words\n",
+ "\n",
+ " synset1 = synsets[0]\n",
+ "\n",
+ " for hyponym in wikiwordnet.get_hyponyms(synset1):\n",
+ " set_words |= { w.lemma() for w in hyponym.get_words()}\n",
+ "\n",
+ " for hypernym in wikiwordnet.get_hypernyms(synset1):\n",
+ " set_words |= { w.lemma() for w in hypernym.get_words()}\n",
+ "\n",
+ " return set_words\n"
+ ],
+ "execution_count": 18,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Tbx59RbOV5Gs"
+ },
+ "source": [
+ "hyp_words = HypWords(prefix_trie=trie)"
+ ],
+ "execution_count": 19,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "wi7COOMY9GPI",
+ "outputId": "a87723b0-cfed-4d5f-e55d-84547466a1bf"
+ },
+ "source": [
+ "question = \"лампа\"\n",
+ "lst = hyp_words.get_words(question, 'л')\n",
+ "print(lst)"
+ ],
+ "execution_count": 20,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "['лампада']\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "lP0Amz0bWoR2"
+ },
+ "source": [
+ "LIST_MODELS = [sum_words, hyp_words]"
+ ],
+ "execution_count": 21,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "3VNMdun5IpwY"
+ },
+ "source": [
+ "def get_score(prefix_len):\n",
+ " n_test, n_success = 0, 0\n",
+ "\n",
+ " for ind in data.index:\n",
+ " n_test += 1\n",
+ "\n",
+ " word, curr_def = data['word'][ind], data['defs'][ind]\n",
+ "\n",
+ " encode_def = encode_def_to_classifier_input(curr_def, nlp)\n",
+ " def_type = np.argmax(classifier.predict_on_batch(encode_def))\n",
+ "\n",
+ " list_words = []\n",
+ " prefix = word[:min(prefix_len, len(word))]\n",
+ " for model in LIST_MODELS:\n",
+ " ans_words = model.get_words(question, prefix)\n",
+ " if len(ans_words):\n",
+ " list_words.append(ans_words[0])\n",
+ "\n",
+ " index = 1 if def_type == 1 and len(list_words) == 2 else 0\n",
+ " if index < len(list_words):\n",
+ " answer = list_words[index]\n",
+ " elif len(list_words):\n",
+ " answer = list_words[0]\n",
+ " else:\n",
+ " answer = \"\"\n",
+ "\n",
+ " if answer == word:\n",
+ " n_success += 1\n",
+ "\n",
+ " return n_success / n_test * 100"
+ ],
+ "execution_count": 22,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "pq5glqGRYvIF",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "974b6fb4-6429-4743-f27d-ac1c6d86b1c1"
+ },
+ "source": [
+ "results = {}\n",
+ "for i in range(0, 7):\n",
+ " score = get_score(i)\n",
+ " results[i] = score\n",
+ "\n",
+ "print(results)"
+ ],
+ "execution_count": 23,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:39: DeprecationWarning: Call to deprecated `most_similar` (Method will be removed in 4.0.0, use self.wv.most_similar() instead).\n"
+ ],
+ "name": "stderr"
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "{0: 0.03454231433506045, 1: 0.0690846286701209, 2: 0.17271157167530224, 3: 0.17271157167530224, 4: 0.17271157167530224, 5: 0.17271157167530224, 6: 0.17271157167530224}\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ }
+ ]
+}
\ No newline at end of file
diff --git a/test/Tests_results.ipynb b/test/Tests_results.ipynb
index 6654a2e..cd5f3cc 100644
--- a/test/Tests_results.ipynb
+++ b/test/Tests_results.ipynb
@@ -15,6 +15,125 @@
"*Новые результаты предлагается добавлять на верх страницы, чтобы избежать проматывания при их поиске*"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 28.06.2021"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Тестовый датасет\n",
+ "\n",
+ "data/test_sets/004_human_set.csv"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Модели поиска ответа:\n",
+ "\n",
+ "* sum_model - /src/models/sum_model.py. Word2vec model - обучена на смешанном датасете определений (формальных с википедии и пользовательских)\n",
+ " * лемматизация слов\n",
+ " * сложение векторов слов\n",
+ " * поиск ближайшего среди слов с соответствующим префиксом\n",
+ " \n",
+ " \n",
+ "* ens_sum_hyp - объединение двух моделей суммирования и гипонимов/гиперонимов (при построении префиксного дерева был использован обученный на смешанном датасете Word2vec). При обработке определения решается, ответ какой из моделей будет передан на выход. Выбор на основе классификатора определений (https://github.com/Alioth6/contact_game_bot/blob/master/src/models/Classification%20with%20LSTM.ipynb)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Метрика\n",
+ "\n",
+ "Проверяется точное соответствие ответа загаданному слову при нескольких первых известных буквах в загаданном слове (от 0 до 4).\n",
+ "\n",
+ "Финальный скор - процент слов, для которых проверка прошла успешно."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Таблица результатов\n",
+ "\n",
+ "Модель | Знаем 0 букв; score, % | 1 буква; score, % | 2 буквы; score, % | 3 буквы; score, % | 4 буквы score, %\n",
+ "------------- | ----------- | ----------- | ----------- | ----------- | -----\n",
+ "sum_model | 17.48 | 31.54 | 35.68 | 37.17 | 37.55\n",
+ "ens_sum_hyp | 0.03 | 0.07 | 0.17 | 0.17 | 0.17"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 18.12.2020"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Тестовый датасет\n",
+ "\n",
+ "data/test_sets/004_human_set.csv - несловарный датасет, содержит около 3000 определений. Датасет собран осенью 2020 при работе бота на платформе Heroku."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Модели поиска ответа:\n",
+ "\n",
+ "* hyp_model - /src/models/hyp_model.py. На базе семантической сети wikiwordnet.\n",
+ " * лемматизация слов\n",
+ " * поиск базовой биграммы фразы\n",
+ " * поиск гипонимов и гиперонимов для слов базовой биграммы\n",
+ "\n",
+ "\n",
+ "* sum_model - /src/models/sum_model.py. Word2vec model - сжатая версия fasttext модели от RusVectores [ft_freqprune_100K_20K_pq_300.bin](https://github.com/avidale/compress-fasttext/releases/tag/v0.0.1).\n",
+ " * лемматизация слов\n",
+ " * сложение векторов слов\n",
+ " * поиск ближайшего среди слов с соответствующим префиксом\n",
+ " \n",
+ " \n",
+ "* sum_plus_hyp_model - объединение двух упомянутых моделей. На выход передаётся до двух слов - результаты от каждой модели. Таким образом мы проверяем, могут ли модели дополнять друг друга - распознают ли они разные определения.\n",
+ "\n",
+ "\n",
+ "* ens_sum_hyp - объединение двух упомянутых моделей. При обработке определения решается, ответ какой из моделей будет передан на выход. Выбор по числу существительных и глаголов в определении."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Метрика\n",
+ "\n",
+ "Проверяется точное соответствие ответа загаданному слову (в случае модели sum_plus_hyp_model - наличие загаданного слова в выходном списке).\n",
+ "\n",
+ "Финальный скор - процент слов, для которых проверка прошла успешно."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Таблица результатов\n",
+ "\n",
+ "Модель | it/s | Score, %\n",
+ "------------- | ----------- | -----\n",
+ "hyp_model | 86.63 | 7,9\n",
+ "sum_model | 1.72 | 16.65\n",
+ "sum_plus_hyp_model | 1.79 | 22\n",
+ "ens_sum_hyp | 1.38 | 17.6\n"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
@@ -176,7 +295,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.8.2"
+ "version": "3.7.3"
}
},
"nbformat": 4,