Created
November 17, 2022 22:45
-
-
Save zhaoyanpeng/affab669325e0d701db0fc35f5d04381 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "id": "897059d7", | |
| "metadata": {}, | |
| "source": [ | |
| "## Env" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "c68042e9", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import json\n", | |
| "import glob\n", | |
| "import copy\n", | |
| "import subprocess\n", | |
| "import numpy as np\n", | |
| "from tqdm import tqdm\n", | |
| "from nltk import word_tokenize\n", | |
| "from collections import Counter, defaultdict\n", | |
| "\n", | |
| "import re\n", | |
| "import io\n", | |
| "import os, sys\n", | |
| "import requests\n", | |
| "import PIL\n", | |
| "from PIL import Image, ImageDraw, ImageFont\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "import random\n", | |
| "from os import path\n", | |
| "from datetime import datetime\n", | |
| "from dataclasses import dataclass\n", | |
| "import nltk\n", | |
| "\n", | |
| "import yaml\n", | |
| "import torch\n", | |
| "from omegaconf import OmegaConf\n", | |
| "from hydra import initialize, initialize_config_module, initialize_config_dir, compose\n", | |
| "\n", | |
| "import pandas as pd\n", | |
| "from functools import partial\n", | |
| "from datasets import load_dataset\n", | |
| "from braceexpand import braceexpand" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "8add443d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Put the local `vreason` checkout first on the import path (skipped if already there).\n", | |
| "vreason_path = \"/mnt/yann/code/vreason\"\n", | |
| "paths = [vreason_path]\n", | |
| "for ppath in paths:\n", | |
| "    if ppath in sys.path:\n", | |
| "        continue\n", | |
| "    sys.path.insert(0, ppath)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "265999ce", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Machine-specific roots. NOTE(review): `repo_root` and `dalle_droot` are not used by any visible cell.\n", | |
| "home = \"/mnt/zhaoyanpeng\"\n", | |
| "repo_root = f\"{home}/code/clevr-ed/caption\"\n", | |
| "dalle_droot = f\"{home}/data/dallemini\"" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "d29cfd0b", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "%load_ext autoreload\n", | |
| "%autoreload 2" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "eec4580b", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# %reload_ext autoreload" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "a865a7b8", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import vreason\n", | |
| "from vreason.module import PretrainedVQGAN\n", | |
| "from vreason.data import build_clevr_image_text_data\n", | |
| "from vreason.monitor import DalleMonitor\n", | |
| "from vreason.module import ENCODER_HEADS_REGISTRY\n", | |
| "from vreason.module import DECODER_HEADS_REGISTRY\n", | |
| "from vreason.model import build_main_model" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "4a40a022", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from matplotlib import pyplot as plt\n", | |
| "from matplotlib import image as ireader\n", | |
| "import matplotlib.patches as patches\n", | |
| "from matplotlib.patches import Rectangle\n", | |
| "import matplotlib.gridspec as gridspec\n", | |
| "plt.rcParams[\"savefig.bbox\"] = 'tight'" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "3811e008", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# %reload_ext autoreload\n", | |
| "from ipynb.fs.full.ipcfg_analysis_parse_functions import (\n", | |
| " convex_hull,\n", | |
| " yxyx_to_xyxy,\n", | |
| " bias_by_width,\n", | |
| " estimate_overlap,\n", | |
| " show_image_bboxes,\n", | |
| " visualize_2d_parse,\n", | |
| " visualize_partial_parse_toy,\n", | |
| " visualize_partial_parse_clevr,\n", | |
| " obj_name_and_box_from_annotations,\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "4f4326e8", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# from vreason.module.decoder.ipcfg_head import IPCFG2DDecHead, IPCFG1DDecHead\n", | |
| "# DECODER_HEADS_REGISTRY._obj_map.pop(\"IPCFG2DDecHead\", \"\")\n", | |
| "# DECODER_HEADS_REGISTRY.register(IPCFG2DDecHead)\n", | |
| "# DECODER_HEADS_REGISTRY._obj_map.pop(\"IPCFG1DDecHead\")\n", | |
| "# DECODER_HEADS_REGISTRY.register(IPCFG1DDecHead)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "2d9bdf03", | |
| "metadata": {}, | |
| "source": [ | |
| "## Conf for loading data" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "9e5e66a4", | |
| "metadata": { | |
| "code_folding": [ | |
| 7, | |
| 16, | |
| 49 | |
| ] | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "alias_root = \"/mnt/yann/model/ipcfg\"\n", | |
| "model_root = \"/mnt/yann/model/ipcfg\"\n", | |
| "\n", | |
| "# pre-trained VQGAN\n", | |
| "vq_root = \"/mnt/yann/model/vqgan_tuned/logs\"\n", | |
| "\n", | |
| "# Two pre-trained VQGAN tokenizers. Both used to be declared as `VQConf`, so the first\n", | |
| "# was silently shadowed by the second (and it also lacked `vq_root`, which the Hydra\n", | |
| "# overrides read). Each now has its own name; pick one explicitly via `VQConf` below.\n", | |
| "\n", | |
| "# 256 tokens x 16384 codes\n", | |
| "class VQConf16384:\n", | |
| "    vq_root = vq_root\n", | |
| "    vq_time = \"2022-09-10T00-25-13\"\n", | |
| "    vq_name = \"clevr_vqgan_128e_16384\"\n", | |
| "    vq_file = \"last.ckpt\"\n", | |
| "    vocab_size = 16384\n", | |
| "    max_vis_len = 256\n", | |
| "    gumbel = False\n", | |
| "\n", | |
| "# 256 tokens x 1024 codes\n", | |
| "class VQConf1024:\n", | |
| "    vq_root = vq_root\n", | |
| "    vq_time = \"2022-09-10T05-13-14\"\n", | |
| "    vq_name = \"clevr_vqgan_128e_1024\"\n", | |
| "    vq_file = \"last.ckpt\"\n", | |
| "    vocab_size = 1024\n", | |
| "    max_vis_len = 256\n", | |
| "    gumbel = False\n", | |
| "\n", | |
| "VQConf = VQConf1024  # switch to VQConf16384 for the 16384-code tokenizer\n", | |
| "vq = VQConf()\n", | |
| "\n", | |
| "text_fold = \"CAPTION_4.0_25_captions\"\n", | |
| "more_root = f\"/mnt/yann/data/CLEVR_v1.0_320x320_bbox\"\n", | |
| "\n", | |
| "# The data directory is keyed by the tokenizer run and checkpoint.\n", | |
| "# (Raw-text location, previously assigned and immediately overwritten:\n", | |
| "#  /mnt/yann/data/CLEVR_text/{text_fold})\n", | |
| "model_sign = vq.vq_file.split(\".\", 1)[0]\n", | |
| "data_root = f\"/mnt/yann/data/dallemini/{vq.vq_time}_{vq.vq_name}/{model_sign}/{text_fold}\"\n", | |
| "\n", | |
| "# Raw strings: the backslashes are meant to reach the Hydra override string verbatim,\n", | |
| "# and `\\{` is not a valid Python escape sequence in a normal string literal.\n", | |
| "data_name = r\"train/\\{001..007\\}.parquet\"\n", | |
| "eval_name = r\"val/\\{001..002\\}.parquet\"\n", | |
| "\n", | |
| "epoch = 1\n", | |
| "bsize = 2\n", | |
| "\n", | |
| "min_obj_num = 1\n", | |
| "max_obj_num = 1\n", | |
| "\n", | |
| "model_name = \"ipcfg_test\"\n", | |
| "model_file = \"00002560.pth\"\n", | |
| "\n", | |
| "alias_name = \"ipcfg_eval_toy\"\n", | |
| "alias_odir = \"toy_5x5\"\n", | |
| "\n", | |
| "class DataConf:\n", | |
| "    data_root = data_root\n", | |
| "    more_root = more_root\n", | |
| "    data_name = data_name\n", | |
| "    eval_name = eval_name\n", | |
| "    txt_vocab_name = \"txt_vocab_is_null\"\n", | |
| "    min_obj_num = min_obj_num\n", | |
| "    max_obj_num = max_obj_num\n", | |
| "    min_train_vid = 0\n", | |
| "    max_train_vid = 9999\n", | |
| "    min_eval_vid = 0\n", | |
| "    max_eval_vid = 999\n", | |
| "    train_samples = None\n", | |
| "    eval_samples = \"[0,128]\"\n", | |
| "    test_samples = None\n", | |
| "    batch_size = bsize\n", | |
| "    vis_vocab_size = 3\n", | |
| "    obj_fake_hw = \"[[3,1],[1,4]]\"\n", | |
| "\n", | |
| "dt = DataConf()\n", | |
| "\n", | |
| "verbose = True\n", | |
| "training = False\n", | |
| "worker = \"IPCFG\"\n", | |
| "Monitor = \"DalleMonitor\"" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "id": "7a64ac25", | |
| "metadata": { | |
| "code_folding": [ | |
| 0, | |
| 11, | |
| 32 | |
| ] | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "base_overrides = [  # Hydra config-group selections\n", | |
| " \"+model/vq=default\",\n", | |
| " \"+model/embedder=dummy\",\n", | |
| " \"+model/encoder=dummy\",\n", | |
| " \"+model/decoder=ipcfg\",\n", | |
| " \"+model/loss=ipcfg\",\n", | |
| "# \"+optimizer=dummy\",\n", | |
| " \"+data=clevr_dalle\",\n", | |
| " \"+running=default\"\n", | |
| "]\n", | |
| "\n", | |
| "special_overrides = [  # VQGAN tokenizer, decoder and run-length settings\n", | |
| " f\"model.vq.gumbel={vq.gumbel}\",\n", | |
| " f\"model.vq.model_root={vq.vq_root}\",\n", | |
| " f\"model.vq.model_name={vq.vq_name}\",\n", | |
| " f\"model.vq.model_file={vq.vq_file}\",\n", | |
| " f\"model.vq.model_time={vq.vq_time}\",\n", | |
| " f\"model.vq.vocab_size={vq.vocab_size}\",\n", | |
| " f\"model.vq.max_vis_len={vq.max_vis_len}\",\n", | |
| " f\"model.vq.load=false\",\n", | |
| " \n", | |
| " f\"+model.encoder.z_dim=0\",\n", | |
| " \n", | |
| " f\"model.decoder.NT=8\",  # NOTE(review): later reset to 6 before the model is built\n", | |
| " f\"model.decoder.T=3\",\n", | |
| " f\"model.decoder.s_dim=256\",\n", | |
| " \n", | |
| " \n", | |
| " f\"running.epochs={epoch}\",\n", | |
| " f\"running.batch_size={bsize}\",\n", | |
| "]\n", | |
| "\n", | |
| "extra_overrides = [  # run, checkpoint and data settings\n", | |
| " f\"port=9990\",\n", | |
| " f\"num_gpus=1\",\n", | |
| " f\"verbose={verbose}\",\n", | |
| " f\"eval={not training}\",\n", | |
| " f\"mode=dp\",\n", | |
| " f\"num_proc=0\",\n", | |
| " f\"seed=1213\",\n", | |
| " f\"rank=0\",\n", | |
| " f\"autocast=false\",\n", | |
| " f\"worker={worker}\",\n", | |
| " f\"alias_odir={alias_odir}\",\n", | |
| " f\"alias_root={alias_root}\",\n", | |
| " f\"alias_name={alias_name}\",\n", | |
| " f\"model_root={model_root}\",\n", | |
| " f\"model_name={model_name}\",\n", | |
| " f\"model_file={model_file}\",\n", | |
| " f\"check_unused_param=false\",\n", | |
| " f\"data.data_root={dt.data_root}\",\n", | |
| " f\"data.more_root={dt.more_root}\",\n", | |
| " f\"data.data_name={dt.data_name}\",\n", | |
| " f\"data.eval_name={dt.eval_name}\",\n", | |
| " f\"data.train_samples={dt.train_samples}\",\n", | |
| " f\"data.eval_samples={dt.eval_samples}\",\n", | |
| " f\"data.test_samples={dt.test_samples}\",\n", | |
| " f\"data.batch_size={dt.batch_size}\",\n", | |
| " \n", | |
| " f\"data.txt_vocab_name={dt.txt_vocab_name}\",\n", | |
| " f\"data.min_obj_num={dt.min_obj_num}\",\n", | |
| " f\"data.max_obj_num={dt.max_obj_num}\",\n", | |
| "# f\"data.={dt.}\",  NOTE(review): this and the next four lines are empty placeholder templates\n", | |
| "# f\"data.={dt.}\",\n", | |
| "# f\"data.={dt.}\",\n", | |
| "# f\"data.={dt.}\",\n", | |
| "# f\"data.={dt.}\",\n", | |
| " \n", | |
| " f\"data.vis_vocab_size={dt.vis_vocab_size}\",\n", | |
| " f\"+data.vis_fake_hw=5\",\n", | |
| " f\"+data.obj_fake_hw={dt.obj_fake_hw}\",\n", | |
| " f\"+data.obj_delt_hw=[]\",\n", | |
| " \n", | |
| " f\"+data.require_txt=false\",\n", | |
| " f\"+data.use_preencoded_pandas=true\",\n", | |
| " f\"+data.txt_special_token={{bos:\\\"<TXT|BOS>\\\",pad:\\\"<PAD>\\\",unk:\\\"<UNK>\\\"}}\",  # `{{`/`}}` are literal braces in an f-string\n", | |
| " f\"+data.vis_special_token={{bos:\\\"<VIS|BOS>\\\",unk:\\\"<UNK>\\\"}}\",\n", | |
| "]\n", | |
| "overrides = base_overrides + extra_overrides + special_overrides" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "id": "7fcc5316", | |
| "metadata": { | |
| "code_folding": [ | |
| 1 | |
| ] | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "abs_config_dir=os.path.abspath(f\"{vreason_path}/configs\")  # Hydra config tree of the vreason checkout\n", | |
| "with initialize_config_dir(config_dir=abs_config_dir):\n", | |
| " cfg = compose(config_name=\"default\", overrides=overrides)  # `default` config + the override list above\n", | |
| "# print(OmegaConf.to_yaml(cfg))\n", | |
| "mcfg = cfg.model\n", | |
| "echo = print  # logger callable handed to the data and model builders" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "id": "97de68a4", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# print(OmegaConf.to_yaml(cfg))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "id": "50713902", | |
| "metadata": { | |
| "code_folding": [ | |
| 0 | |
| ] | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def bias_by_width(argmax_list):\n", | |
| "    \"\"\"Print how often predicted boxes are a single column or a single row wide.\n", | |
| "\n", | |
| "    NOTE: this redefines the `bias_by_width` imported from\n", | |
| "    `ipcfg_analysis_parse_functions` earlier; this local version shadows it.\n", | |
| "\n", | |
| "    Each sample's boxes are converted with `yxyx_to_xyxy` and then counted:\n", | |
| "    `col` if x2 - x1 == 1, `row` if y2 - y1 == 1, `other` if neither. A 1x1 box\n", | |
| "    is counted as both `row` and `col`. Ratios are relative to the total number\n", | |
| "    of boxes; the numbers in parentheses are per-sample averages.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        argmax_list: iterable of per-sample box lists, e.g. the `argmax` returned\n", | |
| "            by `infer_batch` (presumably in (y1, x1, y2, x2) order -- TODO confirm).\n", | |
| "    \"\"\"\n", | |
| "    row = col = other = total = nsample = 0\n", | |
| "    for pred in argmax_list:\n", | |
| "        pred = yxyx_to_xyxy(pred)\n", | |
| "        nsample += 1\n", | |
| "        total += len(pred)\n", | |
| "        for x1, y1, x2, y2 in pred:\n", | |
| "            if x2 - x1 == 1:\n", | |
| "                col += 1\n", | |
| "            if y2 - y1 == 1:\n", | |
| "                row += 1\n", | |
| "            if x2 - x1 != 1 and y2 - y1 != 1:\n", | |
| "                other += 1\n", | |
| "\n", | |
| "    # An empty batch (or samples without any boxes) used to raise ZeroDivisionError.\n", | |
| "    if total == 0:\n", | |
| "        print(f\"Biased by nothing: no predicted boxes in {nsample} sample(s)\")\n", | |
| "        return\n", | |
| "\n", | |
| "    ratio_row, ratio_col, ratio_other = [x / total for x in [row, col, other]]\n", | |
| "    print(\n", | |
| "        f\"Biased by \"\n", | |
| "        f\"row {ratio_row:.3f} ({row / nsample:.3f}) \"\n", | |
| "        f\"col {ratio_col:.3f} ({col / nsample:.3f}) \"\n", | |
| "        f\"other {ratio_other:.3f} ({other / nsample:.3f}) # sample {nsample}\"\n", | |
| "    )" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "55398be7", | |
| "metadata": {}, | |
| "source": [ | |
| "## Infer" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "ac7bba21", | |
| "metadata": {}, | |
| "source": [ | |
| "### Load data" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "id": "f0f1d19b", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Adjust the composed config for the toy-grid evaluation.\n", | |
| "cfg.data.batch_size = 4\n", | |
| "cfg.data.obj_fake_hw = [[4,1],[1,3]]\n", | |
| "convas_size = (16 * 5 + 1,) * 2  # canvas for the default 5x5 grid; presumably 16 px per cell + 1 px border -- TODO confirm\n", | |
| "\n", | |
| "linear = True  # when True, switch to a 6x6 grid and recompute `convas_size` below\n", | |
| "if linear:\n", | |
| " grid_size = [6,6]\n", | |
| " cfg.data.vis_fake_hw=grid_size\n", | |
| " cfg.data.obj_fake_hw=[[4,1],[1,3]] #[[4,1],[1,3]] #\n", | |
| " cfg.data.obj_delt_hw=[]\n", | |
| " cfg.data.eval_samples=1024\n", | |
| " cfg.model.decoder.grid_size=grid_size\n", | |
| " \n", | |
| " convas_size = (16 * grid_size[0] + 1, 16 * grid_size[1] + 1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "id": "500a7b7e", | |
| "metadata": { | |
| "code_folding": [ | |
| 0 | |
| ] | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Err: [Errno 2] No such file or directory: '/mnt/yann/data/dallemini/2022-09-10T05-13-14_clevr_vqgan_128e_1024/last/CAPTION_4.0_25_captions/vis_vocab_is_null'\n", | |
| "Err: [Errno 2] No such file or directory: '/mnt/yann/data/dallemini/2022-09-10T05-13-14_clevr_vqgan_128e_1024/last/CAPTION_4.0_25_captions/txt_vocab_is_null'\n", | |
| "Load 500 (2000) batches (main).\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Resolve the dataset builder by name (e.g. `build_clevr_image_text_data`) instead of\n", | |
| "# `eval`-ing a string assembled from the config.\n", | |
| "data_builder = globals()[f\"build_{cfg.data.name.lower()}_image_text_data\"]\n", | |
| "dataloader, evalloader, testloader, encoder_vocab, decoder_vocab, cate_vocab = data_builder(\n", | |
| "    cfg.data, not cfg.eval, echo, ddp_mode=(cfg.mode == \"ddp\")\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "id": "20e1c7ce", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# for ix, x in enumerate(dataloader):\n", | |
| "# # print(x)\n", | |
| "# if ix == 1:\n", | |
| "# break" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "id": "0e46f7e5", | |
| "metadata": { | |
| "code_folding": [ | |
| 0 | |
| ] | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def check_data(dataloader, k=3):\n", | |
| "    \"\"\"Collect every batch's `image` array and one-hot encode the result.\n", | |
| "\n", | |
| "    Returns `(images, one_hot)`: `images` concatenates each `batch[\"image\"]`\n", | |
| "    along axis 0, and `one_hot = np.eye(k)[images]` adds a trailing axis of\n", | |
| "    size `k` (assumes the image values are integer ids in [0, k)).\n", | |
| "    \"\"\"\n", | |
| "    batches = [batch[\"image\"] for batch in dataloader]\n", | |
| "    stacked = np.concatenate(batches, axis=0)\n", | |
| "    return stacked, np.eye(k)[stacked]\n", | |
| "\n", | |
| "\n", | |
| "image_all, image_onehot = check_data(dataloader)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "id": "0297a25a", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "x1 = image_onehot.mean(0)  # average of the one-hot ids over samples, i.e. per-position id frequencies" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "id": "3ac3c67d", | |
| "metadata": { | |
| "code_folding": [] | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# x1" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "id": "b90f6569", | |
| "metadata": { | |
| "code_folding": [] | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# x2 = image_onehot.mean(0)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "id": "83c8ed6c", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# x2" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "bc6355a6", | |
| "metadata": {}, | |
| "source": [ | |
| "### Parse" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 23, | |
| "id": "7d7834ba", | |
| "metadata": { | |
| "code_folding": [ | |
| 0, | |
| 13 | |
| ] | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def make_batch(sample, device):\n", | |
| "    \"\"\"Move one dataloader sample onto `device` as torch tensors.\n", | |
| "\n", | |
| "    `image`, `text` and `vfile` are popped from `sample`, so the caller's dict is\n", | |
| "    mutated and keeps only its remaining keys. 4-D images are permuted with\n", | |
| "    (0, 3, 1, 2) (presumably channels-last -> channels-first); lower-rank inputs,\n", | |
| "    e.g. pre-encoded image token sequences, are left unchanged. `text_mask` is\n", | |
| "    True where `text` equals the notebook-global `encoder_vocab.PAD_IDX`.\n", | |
| "    \"\"\"\n", | |
| "    image = torch.from_numpy(sample.pop(\"image\")).to(device)\n", | |
| "    if len(image.shape) == 4:  # can be pre-encoded image token sequences\n", | |
| "        image = image.permute(0, 3, 1, 2)\n", | |
| "\n", | |
| "    text = torch.from_numpy(sample.pop(\"text\")).to(device)\n", | |
| "    text_mask = text == encoder_vocab.PAD_IDX\n", | |
| "\n", | |
| "    return {\"image\": image, \"text\": text, \"text_mask\": text_mask, \"vfile\": sample.pop(\"vfile\")}\n", | |
| "\n", | |
| "\n", | |
| "def infer_batch(model, device, kbatch=4, loader=None):\n", | |
| "    \"\"\"Parse the `kbatch`-th batch of a dataloader with the model's decoder head.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        model: model exposing `decoder_head` (called in inference mode) and\n", | |
| "            `decoder_head.pcfgs`.\n", | |
| "        device: device the batch tensors are moved to.\n", | |
| "        kbatch: zero-based index of the batch to parse.\n", | |
| "        loader: dataloader to draw from; defaults to the notebook-global\n", | |
| "            `dataloader` (the previous, hard-wired behaviour).\n", | |
| "\n", | |
| "    Returns:\n", | |
| "        (batch_dict, argmax, pcfgs): the device batch merged with the sample's\n", | |
| "        remaining keys, the decoder's argmax output, and `decoder_head.pcfgs`.\n", | |
| "\n", | |
| "    Raises:\n", | |
| "        IndexError: if `loader` yields no batch at index `kbatch` (previously an\n", | |
| "            UnboundLocalError on `batch_dict`).\n", | |
| "    \"\"\"\n", | |
| "    import itertools\n", | |
| "\n", | |
| "    if loader is None:\n", | |
| "        loader = dataloader\n", | |
| "\n", | |
| "    missing = object()\n", | |
| "    # Skip straight to the requested batch; later batches are never drawn.\n", | |
| "    batch = next(itertools.islice(loader, kbatch, None), missing)\n", | |
| "    if batch is missing:\n", | |
| "        raise IndexError(f\"dataloader has no batch at index {kbatch}\")\n", | |
| "\n", | |
| "    batch_dict = make_batch(batch, device)\n", | |
| "    batch_dict.update(batch)  # keep the sample's other fields, e.g. `bbox`\n", | |
| "\n", | |
| "    kwargs = {\n", | |
| "        \"infer\": True, \"auto_infer\": False, \"exclude_trivial\": True,\n", | |
| "        \"require_marginal\": False, \"marginal_as_dict\": True, \"mbr\": False,\n", | |
| "    }\n", | |
| "    with torch.no_grad():\n", | |
| "        ll, kl, targets, argmax, marginal, dec_extra = model.decoder_head(\n", | |
| "            None, None, v_seq=batch_dict[\"image\"], mean=None, lvar=None, **kwargs\n", | |
| "        )\n", | |
| "\n", | |
| "    return batch_dict, argmax, model.decoder_head.pcfgs" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "75b96a96", | |
| "metadata": { | |
| "heading_collapsed": true | |
| }, | |
| "source": [ | |
| "### Load model" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "256c0628", | |
| "metadata": {}, | |
| "source": [ | |
| "### Model A" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 238, | |
| "id": "d57a4172", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Checkpoint to evaluate. Only the final assignment to `cfg.model_name` took effect\n", | |
| "# before (two earlier ones were immediately overwritten); candidates are kept as comments.\n", | |
| "# cfg.model_name = \"ipcfg_test_662341_12_s8765_1585_50_85_3_1\"\n", | |
| "# cfg.model_name = \"ipcfg_test_661341_12_s8765_1585_50_85_3_1\"\n", | |
| "# cfg.model_name = \"ipcfg_test_664113_12_s8765_1585_50_85_5_1\" # best\n", | |
| "# cfg.model_name = \"ipcfg_test_664123_12_s8765_1585_50_85_5_1\" # best\n", | |
| "# cfg.model_name = \"ipcfg_test_661431_12_s8765_1585_50_85_5_1\"\n", | |
| "# cfg.model_name = \"ipcfg_test_661432_12_s8765_1585_50_85_5_1\"\n", | |
| "# cfg.model_name = \"ipcfg_test_662432_12_s8765_1585_50_85_5_1\"\n", | |
| "# cfg.model_name = \"ipcfg_test_663224_12_s8765_1585_50_85_5_1\"\n", | |
| "# cfg.model_name = \"ipcfg_test_662342_12_s8765_1585_50_85_5_1\"\n", | |
| "# cfg.model_name = \"ipcfg_test_664223_12_s8765_1585_50_85_5_1\"\n", | |
| "# cfg.model_name = \"ipcfg_test_664223_12_s8765_1585_50_85_5_1_r\"\n", | |
| "# cfg.model_name = \"ipcfg_test_662432_12_s8765_1585_50_85_5_1_r\"\n", | |
| "# cfg.model_name = \"ipcfg_test_661431_12_s8765_1585_50_85_5_1_r\"\n", | |
| "# cfg.model_name = \"ipcfg_test_664113_12_s8765_1585_50_85_5_1_r\"\n", | |
| "# cfg.model_name = \"ipcfg_test_664113_12_s8765_1585_50_85_5_1_r_e50\"\n", | |
| "# cfg.model_name = \"ipcfg_test_661431_12_s8765_1585_50_85_5_1_r_e50\" # best\n", | |
| "# cfg.model_name = \"ipcfg_test_664123_12_s8765\"\n", | |
| "cfg.model_name = \"ipcfg_test_oneset\"\n", | |
| "cfg.model_file = \"last.pth\"\n", | |
| "# cfg.model_file = \"00000640.pth\"\n", | |
| "\n", | |
| "# Decoder shape must match the checkpoint (its stored config has NT: 6, n_set: 1).\n", | |
| "cfg.model.decoder.NT = 6\n", | |
| "cfg.model.decoder.n_set = 1" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "55c98afa", | |
| "metadata": {}, | |
| "source": [ | |
| "### Model B" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 239, | |
| "id": "cb582574", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Optional: evaluate the `IPCFG2DLADecHead` decoder variant and its checkpoint instead.\n", | |
| "ipcfg_la = False\n", | |
| "if ipcfg_la:\n", | |
| "    cfg.model.decoder.name = \"IPCFG2DLADecHead\"\n", | |
| "    cfg.model.decoder.nt_la = True\n", | |
| "    cfg.model_name, cfg.model_file = \"ipcfg_toyla_664113_s1\", \"00001920.pth\"" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 240, | |
| "id": "e947a42b", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Loading from /mnt/yann/model/ipcfg/ipcfg_test_oneset/last.pth\n", | |
| "Old configs:\n", | |
| "\n", | |
| "alias_root: /mnt/yann/model/ipcfg\n", | |
| "model_root: /mnt/yann/model/ipcfg\n", | |
| "alias_name: ipcfg_test_oneset\n", | |
| "model_name: test\n", | |
| "alias_odir: outs\n", | |
| "model_file: ''\n", | |
| "model_time: null\n", | |
| "blockprint: false\n", | |
| "purge_root: false\n", | |
| "monitor: DalleMonitor\n", | |
| "worker: IPCFG\n", | |
| "verbose: true\n", | |
| "seed: 9876\n", | |
| "eval: false\n", | |
| "rank: 0\n", | |
| "mode: ddp\n", | |
| "autocast: false\n", | |
| "check_unused_param: false\n", | |
| "num_proc: 1\n", | |
| "num_gpus: 1\n", | |
| "port: 1789\n", | |
| "dist_url: tcp://localhost:${port}\n", | |
| "model:\n", | |
| " vq:\n", | |
| " name: PretrainedVQGAN\n", | |
| " load: false\n", | |
| " gumbel: false\n", | |
| " model_root: /mnt/yann/model/vqgan_tuned/logs\n", | |
| " model_name: clevr_vqgan_128e_1024\n", | |
| " model_file: last.ckpt\n", | |
| " model_time: 2022-09-10T05-13-14\n", | |
| " vocab_size: 1024\n", | |
| " embed_size: 256\n", | |
| " max_vis_len: 256\n", | |
| " embedder:\n", | |
| " name: DummyEncHead\n", | |
| " encoder:\n", | |
| " name: DummyEncHead\n", | |
| " z_dim: 0\n", | |
| " decoder:\n", | |
| " name: IPCFG2DDecHead\n", | |
| " NT: 6\n", | |
| " T: 3\n", | |
| " h_dim: 512\n", | |
| " w_dim: 512\n", | |
| " z_dim: ${..encoder.z_dim}\n", | |
| " s_dim: 256\n", | |
| " n_set: 1\n", | |
| " nt_la: false\n", | |
| " drop_1d: false\n", | |
| " rate_1d: 5.0\n", | |
| " beta_1d: 0.1\n", | |
| " drop_2d: false\n", | |
| " rate_2d: 3.0\n", | |
| " beta_2d: 0.1\n", | |
| " mini_1d_ll: ${..loss.mini_1d_ll}\n", | |
| " mini_1d_2d: ${..loss.mini_1d_2d}\n", | |
| " grid_size:\n", | |
| " - 6\n", | |
| " - 6\n", | |
| " loss:\n", | |
| " name: PCFGLossHead\n", | |
| " layers: []\n", | |
| " scaling: false\n", | |
| " mini_1d_ll: false\n", | |
| " mini_1d_2d: false\n", | |
| " kl_decay_ratio: 1.0\n", | |
| " kl_cycle_steps: 250.0\n", | |
| " kl_activation: cosine\n", | |
| " kl_max_beta: 1.0\n", | |
| " kl_cycle: false\n", | |
| " kl_max: null\n", | |
| " bh_beta: 0\n", | |
| " sh_beta: 1\n", | |
| " th_beta: 0\n", | |
| "optimizer:\n", | |
| " use_lars: false\n", | |
| " name: ${.opt_name}\n", | |
| " warmup: false\n", | |
| " warmup_steps: 250.0\n", | |
| " warmup_times: ${.warmup_steps}\n", | |
| " warmup_epoch: 10\n", | |
| " warmup_fn: linear\n", | |
| " decay_step: 1000.0\n", | |
| " decay_rate: 1.0\n", | |
| " lr: 0.01\n", | |
| " weight_decay: 1.0e-07\n", | |
| " betas:\n", | |
| " - 0.9\n", | |
| " - 0.75\n", | |
| " max_gnorm: 3\n", | |
| " lr_weight: 0.2\n", | |
| " lr_bias: 0.0048\n", | |
| " batch_size: ${running.batch_size}\n", | |
| " epochs: ${running.epochs}\n", | |
| " steps: []\n", | |
| " gamma: 0.5\n", | |
| " batch_sch: true\n", | |
| " opt_name: AdamW\n", | |
| " sch_name: WarmupExpDecayLR\n", | |
| " optimizer:\n", | |
| " - ${optimizer.opt_name}\n", | |
| " - lr: ${optimizer.lr}\n", | |
| " betas: ${optimizer.betas}\n", | |
| " scheduler:\n", | |
| " - ${optimizer.sch_name}\n", | |
| " - step_size: ${optimizer.decay_step}\n", | |
| " gamma: ${optimizer.decay_rate}\n", | |
| " warmup_step: ${optimizer.warmup_steps}\n", | |
| " warmup_fn: ${optimizer.warmup_fn}\n", | |
| "data:\n", | |
| " name: CLEVR\n", | |
| " version: 0.0.1\n", | |
| " num_proc: ${num_proc}\n", | |
| " data_root: /mnt/yann/data/dallemini/2022-09-10T05-13-14_clevr_vqgan_128e_1024/last/CAPTION_4.0_25_captions\n", | |
| " more_root: /mnt/yann/data/CLEVR_v1.0_320x320_bbox\n", | |
| " dump_root: ${.data_root}/${model_time}_{model_name}\n", | |
| " data_name: train/{001..007}.parquet\n", | |
| " eval_name: val/{001..002}.parquet\n", | |
| " test_name: ''\n", | |
| " data_seed: null\n", | |
| " train_samples:\n", | |
| " - 0\n", | |
| " - 1024\n", | |
| " eval_samples:\n", | |
| " - 0\n", | |
| " - 512\n", | |
| " test_samples: null\n", | |
| " batch_size: ${running.batch_size}\n", | |
| " crop_size:\n", | |
| " - 0\n", | |
| " - 0\n", | |
| " - 256\n", | |
| " - 256\n", | |
| " resize_size: 256\n", | |
| " max_obj_num: 1\n", | |
| " min_obj_num: 1\n", | |
| " max_txt_len: 64\n", | |
| " min_txt_num: 2\n", | |
| " max_txt_num: 100\n", | |
| " min_pixel_num: 2\n", | |
| " min_train_vid: 0\n", | |
| " max_train_vid: 9999\n", | |
| " train_divider: 10000\n", | |
| " min_eval_vid: 0\n", | |
| " max_eval_vid: 999\n", | |
| " eval_divider: 2000\n", | |
| " txt_vocab_name: txt_vocab_is_null\n", | |
| " vis_vocab_name: vis_vocab_is_null\n", | |
| " vis_vocab_size: 3\n", | |
| " txt_special_token:\n", | |
| " bos: <TXT|BOS>\n", | |
| " pad: <PAD>\n", | |
| " unk: <UNK>\n", | |
| " vis_special_token:\n", | |
| " bos: <VIS|BOS>\n", | |
| " unk: <UNK>\n", | |
| " txt_only: false\n", | |
| " vis_only: false\n", | |
| " use_preencoded_pandas: true\n", | |
| " require_txt: false\n", | |
| " rot90_image: 0.0\n", | |
| " vis_fake_hw:\n", | |
| " - 6\n", | |
| " - 6\n", | |
| " obj_fake_hw:\n", | |
| " - - 4\n", | |
| " - 1\n", | |
| " - - 1\n", | |
| " - 3\n", | |
| " obj_delt_hw: []\n", | |
| "running:\n", | |
| " peep_rate: 32\n", | |
| " topk_best: 3\n", | |
| " save_rate: 10000000000.0\n", | |
| " skip_save: false\n", | |
| " save_last: false\n", | |
| " nsampling: 1\n", | |
| " infer_mode: false\n", | |
| " batch_size: 8\n", | |
| " epochs: 30\n", | |
| " save_epoch: true\n", | |
| " optim_rate: 1\n", | |
| " v_peep_time: 1\n", | |
| " v_peep_topk: -1\n", | |
| "\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Build the model; the cell output shows the checkpoint path it loads.\n", | |
| "model = build_main_model(cfg, echo)\n", | |
| "model.build(**{\n", | |
| " \"encoder_vocab\": encoder_vocab,\n", | |
| " \"decoder_vocab\": decoder_vocab,\n", | |
| "})\n", | |
| "\n", | |
| "device = next(model.decoder_head.parameters()).device  # run inference where the decoder head's weights live" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "6e6bbbcc", | |
| "metadata": {}, | |
| "source": [ | |
| "### Plots" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 241, | |
| "id": "e787ddee", | |
| "metadata": { | |
| "code_folding": [ | |
| 0 | |
| ], | |
| "scrolled": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def show_results(vidx=2):\n", | |
| "    \"\"\"Plot the gold boxes of sample `vidx` and report how predictions overlap them.\n", | |
| "\n", | |
| "    Reads the notebook globals `batch_dict` and `argmax` (from `infer_batch`) and\n", | |
| "    `convas_size`. Predictions are converted with `yxyx_to_xyxy` and sorted by\n", | |
| "    descending first coordinate before being passed to\n", | |
| "    `visualize_partial_parse_toy`. Its result maps each gold box to\n", | |
| "    `(box, overlap)` pairs, which are printed one per line.\n", | |
| "    \"\"\"\n", | |
| "    gold = batch_dict[\"bbox\"][vidx]\n", | |
| "    pred = sorted(yxyx_to_xyxy(argmax[vidx]), key=lambda box: -box[0])\n", | |
| "\n", | |
| "    # (An area-sorted copy of `pred` used to be computed here but was never used.)\n", | |
| "    visualize_2d_parse(\n", | |
| "        [], gold=gold, area=(0, 0, 16, 16), scale=1.0, figsize=(5, 5),\n", | |
| "        rotate=False, new_size=None, convas_size=convas_size\n", | |
| "    )\n", | |
| "\n", | |
| "    overlap_dict = visualize_partial_parse_toy(gold, pred, area=(0, 0, 16, 16), convas_size=convas_size)\n", | |
| "    print()\n", | |
| "    for gold_box, matches in overlap_dict.items():\n", | |
| "        print(gold_box, len(matches))\n", | |
| "        print(\"\\n\".join(f\"{a}\\t{b:.3f}\" for a, b in matches))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 247, | |
| "id": "c16a80fb", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Biased by row 0.171 (6.000) col 0.514 (18.000) other 0.314 (11.000) # sample 4\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "batch_dict, argmax, pcfgs = infer_batch(model, device, kbatch=0)  # parse the first batch\n", | |
| "bias_by_width(argmax)  # share of predicted boxes that are one column / one row wide" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 250, | |
| "id": "aa5f9b0f", | |
| "metadata": { | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(1, 0, 4, 1) 5\n", | |
| "(3, 2, 4, 6) 5\n", | |
| "\n", | |
| "(1, 0, 4, 1) 5\n", | |
| "[1, 0, 2, 2]\t0.400\n", | |
| "[0, 0, 6, 2]\t0.400\n", | |
| "[0, 0, 4, 2]\t0.545\n", | |
| "[2, 0, 4, 2]\t0.571\n", | |
| "[2, 0, 4, 1]\t0.800\n", | |
| "(3, 2, 4, 6) 5\n", | |
| "[0, 2, 5, 6]\t0.333\n", | |
| "[0, 2, 4, 6]\t0.400\n", | |
| "[3, 2, 4, 4]\t0.667\n", | |
| "[3, 4, 4, 6]\t0.667\n", | |
| "[3, 2, 4, 6]\t1.000\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR4AAAE2CAYAAAC3Nv2NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAARKElEQVR4nO3cfYwtd13H8c9nHs6tpIWqNTal2PqHxig+xGukNWJQI1IioTESYysNFUhrL/UBlEAjSpSAEtSYUHsbYmiVxOIDNSAgkcQQCeWPXo2JRDFVS1qlVPoAfdI9M/PzjzNzdvbcu3tmd+9+z8zt+5Wc3HN2zs5+Zs7MZ36/2c11SkkAECnbdAAAzz4UD4BwFA+AcBQPgHAUD4BwFA+AcMVeC0+dOjW537XfcMMNm46Akbv99ts3HeFZ4fjx495t2Z7F0+FvfXAu4Xg+WvaufbO0tnhSSkopqWma0X9g/Q2+7bbbJpE5yzJl2WLG2zSNmqbZcKL1ppr5pptukjSNzLaVZdnymK7revTHcpe5e76XQSOepmm0tbU1+g+rKLY3p65rbW1tjf7DKstSZVnKtuq61nw+H3Vm2yqKQrPZTJJUVZWqqhp95rIsl6/n87mqqtpgovVsazabqSgKpZSW+3nMsizTbDZTnudr3zt4qjWF0UM3OrO9fD72zP0y7/bzmPX3aX80PGarx8EUjgtJO/bzlM6/IfitFoBwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHDFkDfZVlEUSikddZ5DybJMtpfP8zzfcKL1+hmzLFNRDPpINqrLbHsymbNs+xo7heNC2s5sW3meL4/tseqOhyEGHTF5nivLstEXT/+DOXHihCRNKrM0/rzS9DMXRTH68rG9zGxbZVmOfj/3M68z+FK1n5WOydQyTy2vNL3MUzyWp5h5L4OKJ6Wkuq5H37hZlunkyZOSpKZp1DTNJDJ3V9+maVTX9YYTrdeNgKVpZO6mAF3muq7VNM2GU+1tqpmHTgkHFU/TNNra2hr9SdydELYnk7koimXmuq61tbW16UhrlWWp2WwmSaqqSvP5fMOJ1pvNZsvbBVVVqaqqTUfak+1l5qZpNJ/PR1/w3XE8ZBo7eMTTPcYupSTbk8nczzeFvNJ25qnsY2na+7l7PvbM+8nIr9MBhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeDbM9qYjAOGKIW/KskxlWappmqPOcyh5ni9P5CzLVBSFUkobTrU728qybJk5z3OVZTn6zHmen/Z8Kpkl7Xg+Vt2x0T0vimL0F6ksy5aZ19lX8UxB/yQeuhM2qX8wTSnzjTfeuOkYB3by5EkVxaBDf6O6Y6MrnillXmfwloy9bc+EzNjN1Pbz1PKuM6h4mqZRXdejHk5Li5FZN92aSuY8z5dD/7qu1TTNJDJ3br31VtV1vcE063XTlhMnTkiSqqoa/W2DbnqYZZlSSstjY8z6mdcZXDzz+Xz0G14UxfKeSV3X2traGv1JXJblMnPTNKPPvHrlreta8/l89Jlns9nydZd5zGzr2LFjy+KpqkpVVW061p72c49n8A2FMR9YU9bfryml0e/n1YxTzTwFU8+8l/HfyQRwzqF4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISj
eACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AIQbXDy2jzLHs9bqfp3Cfp5Cxj7bOzJPJf8UMw9VDHlTlmUqikIppaPOcyh5ni8/oDzPVZbl5DJLGnVm28uc0nb+sWfOsu1rbD//WPUzd/t87OWzup/3Mrh4ZrPZoUJFm2LmPM8ncVL0FcWgQ2hUiqKYVG7bKsty0zHOqn3t/TFf1aQzD0fJfPatZp5aXonMR2E/I7JBxdM0jaqqGv2Gd1NC26rrWnVdjz5znufLq++UMneqqlJd1xtMs97q9HAqmYuiUJZlSimpqio1TbPpWHvqZ15nX8Uz9g0vimI5F26aRvP5fPQnsbR9n6Su60lk7l/ZusxjtnolrqpKVVVtKM0w3Q3xrnjquh595izLlo+17x260rGfDFO1ul+nsJ+nkLFvannP5FzYhj7+jgdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQLjBxWP7KHM8a/X3q+3R7+fVjFPNPHZTzCwNz1kMeVOWZSrLUimlQ4U6almWLTc8z3PNZrPRZ87zfJk5y7LRZ7atLNu+XuV5vsE0w9jekbMoitGfyP39bFtFUezY72O0n4vQ4OIZ+0avmmLmPM8ncSL3kfnodcVzLhm8NWO+Cvf1G5fMR2P1qja1zFPIK00/814GFU/TNJrP56Pf+DzPl8Pouq5VVdUkMpdlKUmqqkp1XY868+q0paoqVVW1wUTrrY4Yuv08Zl3mPM+VUppE5izLBk8JBxfPFE5iafueyZQyd2XZFfyUdAU/dv0r8RQyd/d4plY8Q29xTOsmCIBzAsUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyBcMeRNtmX7qLOcFV3OqWTuZyTz0VjNSOajsZ+Mg4onyzLNZjOllA4V7KhlWbbc8ClmzvNcx44dG3Vm28qy7YFyURSTOCFWM/dfj1E/c5ZlKstSeZ5vONXeVvfzXgaPeIpi0FtHI8uy0R9cq8gcI8/z0Z/Eq6aYeS+D2ySlNOorsXT6UK9pmg2mGaafeQr7WNo51SLz0ekX+hQyn/WpVl3Xqqpq9Cdynucqy1K2Vde15vP56D+soiiW05VuP485s+0dV96qqlTX9egz90fsVVWpqqoNJlrP9nJ6lVLSfD5XXdebjrWnbko4ZAQ8qHhSSqM/IaTFh5VSkm01TTP6E0LaeVI0TTP6E0LSaaPKKWTunwx1XY/+JO4KviueKWROKSnP80HFM63JOYBzAsUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyBcMeRNtpVlmZqmOeo8h2JbtpfPx565y7j6OqW0wVR76+/j7vXUMmdZNvrMWZZN6liWTt/PexlUPFmWaTabjfqDknZ+WHmey/bkMo/9
hJC0oyxvvvnmDSbZn24/F0WxYxvGqH9Rsq2yLJXn+YZT7W31QrqXtcXTPymmxPbkMo/9ZDiToVe4MelGPFMypcxDjgnvdXU9derUuC+9AEbr+PHjuzbQnsUDAEdhGmM3AOcUigdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igeTYvuk7bdtOgcOh+I5JNvX2L7X9pO2v2j747Z/sF32dtsf2OX77rf9TPt93eO9tq+w/ZTt88/wPf9o+w22L7edet93v+23HPW2thl23aYIKaUbU0q/tamfj7Nj0P/HgzOz/UZJb5F0o6RPSNqS9DJJr5T06QGreEVK6ZNnWO+Dkn5K0h29r71Q0rdL+lNJF7RfvjClVNn+Pkmfsn0qpfS3B98iIAYjngOy/TxJvynpRErpQymlp1JK85TSR1JKv3rI1d8p6bqVr10n6WMppUdW35xSulfS5yR9z5CV236N7U/bfo/tx2z/p+2ressvsf1h24/avs/269uvv0zSLZJ+uh1p/dOan3O97X+x/YTt/7B9Q2/ZS2w/aPtNth9uR4vXD8h+h+13rKzjzb11XG375bb/rc1/S+97v9/2PbYfb9/7Xtuz3vKX2v687a/Y/kPbn7L9ut7yn2u35zHbn7B92ZD9jdNRPAd3paTzJN19BOv+E0k/ZPsFkmQ7k3SNFoV0GttXSHqhpPt6X3u8m/Lt4kWSPi/pIknvlvRH3v4fnO6S9KCkS7QYeb3T9o+klP5G0jslfTCldH5K6bvXbMfDkn5C0nMlXS/p921/b2/5xZKeJ+n5kl4r6VbbX7tmnasu1uJzeL6kX5f0Pkk/K+m4pBdLepvtb27fW0v65Xabr5T0o5JukiTbF0n6C0lvlfT1WuybH+h+iO1XalG6PynpGyT9vRajTxxESonHAR6SrpX00Jr3vF3SB3ZZdr+kJyU93nu8vrf8k5JuaZ//mKT/kVS2ry+XlNrveaZ9/h61/7/SgOyvkXRf7/Vz2nVcLOkFWpygF/SWv0vSHeu2acDP/StJv9g+f0mbvegtf1jSFWvWcYekd6ysI29fX9Bux4t67z8l6epd1vVLku5un18n6Z7eMkt6QNLr2tcfl/Ta3vJM0tOSLtv0sTjFByOeg3tE0kW2D3Of7OqU0oW9x/t6y+6U9Or2+asl3ZVSmq98/0WSzpf0Ji1OwnIfP/uh7klK6en26flajHIeTSk90XvvF7QYUeyL7atsf7ad8jwu6eVt5s4jKaWq9/rpNsN+PJJSqtvnz7T/fqm3/Jlunba/1fZf237I9le1GL11eS7RomgkSWnRLg/21nOZpD9oR5KPS3pUi3La934BU63DuEfS/0m6+ojW/yFJl9r+YS2G92ecZqWU6pTS70n6X7XThkP6b0lfZ/uC3te+SdJ/dT9yyEpsH5P0l1qMxL4xpXShpI9pcbJuym2S/lXSt6SUnqvF1KnL80VJl3ZvbKedl/a+9wFJN6xcKL4mpfSZoOznFIrngFJKX9HinsKt7Q3N59gu26v8u3tvzWyf13scG7j+p7S45/B+SV9IixvIe/ltSW+2fd5Btqf3cx+Q9BlJ72rzfpcW91+6X6F/SdLl7X2nvcwkHdNiili1N69fephsZ8EFkr4q6Unb3ybp53vLPirpO9vPspB0QoupZ+ekpLfa/g5p8csF268Kyn3OoXgOIaX0u5LeKOnXtDjBHpD0Bi3uZXR+Rovhfvf4996yj3jn3/Gs3qi+U4sh/h8PiPNRSY9J6n4D9aTtF+97o7YzX67F6OduSb+Rtn/t/+ftv4/Y/ofdVtBO1X5B0p+1ua6R9OED5jlbfqXN8YQWN6E/2C1IKX1Z0qu0uNH+iBZ/unCvFqNapZTulvQ7ku5qp2n/LOkq4UD4z96BM2hHdA9Kujal9HebznOuYcQDtGz/uO0L2+lwd//nsxuOdU6ieHAoK1PF/uOg0zzZ/twu67z2bGY/gyu1mAp/WdIrtPit4zN7fwsO
gqkWgHCMeACEo3gAhKN4AISjeACEo3gAhKN4AIT7f9D3An8T9A2yAAAAAElFTkSuQmCC\n", | |
| "text/plain": [ | |
| "<Figure size 360x360 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| }, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR4AAAE2CAYAAAC3Nv2NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAQZUlEQVR4nO3cb4xld13H8c/3d865u5FuWtsasoi0T4yJQSXZXSiRav2XUBNDNSRGEhoNISVtfYJRGzUCJiIhYmIiTZtKsEoiWEMbCCjoA41I2aWDmCiGSLWkZZXYUnRpu8w953x9cOfcOXs7c+dMd+Z7fmf6fiUbZnfu3P3cO/e877l3lpq7CwAipbEHAHjxITwAwhEeAOEID4BwhAdAOMIDIFy57pMbGxuT+1n7bbfdNvYEZO7ee+8de8KLwqlTp2y3z60NT+c1n/+smmuePbhFh+ni02MveEFuaP9Dx5uxVxx9Fwvp9OnvHHvGkXfyZKPz53f//J7hcXc11zyrz73iJuX+jw3NTHccf0AbGxs6d+6c2rbNfnNKSSklnf7CGZ171Vm1bTv2pD11myWpbdvJbL799tt14sKGzp59MvvNZqaUkswWJw1N02T/WO5vPnPm6rWXHXTGI0mbm5vZf7PKcvvmNE2jzc3N7L9ZVVWpqipJi83z+TzrzWamsiw1m80kSXVdq67r7Dd397Ekzedz1XU94qK9mZlms5nKspS7L+/nnKWUNJvNVBTFnpcdHJ4pnD24+3Jj93Hum/sxd/fs477TfTylzTv9Plf9+3lqx99e+KkWgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMKVgy9YlnL3w9xy2VJKMrPlx0VRjLxob/2NKSWV5eBvyWi6zWY2mc0pbT/HTuFxIW1vNjMVRbF8bOeqezwMMfgRM5vNsg9P/xtzxx13SNJkNp+4IF04cWf2eyU97wCY2uayLLOPj5ktN5uZqqrK/n7ub97L4PDs50pzMpXNFwvpxIWNsWcceReLaT6Wp7h5ncHhqes6++KmlHTPPffozJkzuvvuu9W27SQ2F0WhM6fP6Oy5s2qaZuxJeyqKYnlK3bZt9pu7lwApJZ05fUbN2UZt2449a63+ZklqmmlsHvqScHB4Njc3sz+IVw+IKWwuy/KSB9fm5ubIi/ZWVZVms5mkxRPSfD4fedHeZrPZ8n6u61p1XY+8aD0zW25u21bz+Tz7wHfvsQ55GTs4PO6e/UEsbb/f0O3NfXN/3xT2StO7j6Vp38/dx7lv3s9GfpwOIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QjPyMxs7AlAuHLoBauqUtu2h7nlshVFsTyQU0oqy1LuPvKq3ZmZUkrLzUVRqKqq7DcXRfG8j6eyWdIlH+eqe2x0H5dlmf2TVEppuXkv+wrPFPQP4qF3wpj6D6Ypbi7LcjIHcqeqKpXl4If+aLrNXXimtHnPy617ptrY2HB31w2PPKzmmmcPaht28tPvkV7yzbFXHH3/d7V05efHXnHknTzZ6Pz57921QoPC07atmqbJ+nRaWpzqdS+3prK5KIrlGUPTNGrbdnKbm6YZedF63cuWbnNd19m/bdC9PEwpyd2Xj42c9TebmU6dOrVreAadu7Vtq/l8nv0NL8tyeaObptHm5mb2B3FV
VcvNbdtmv7n/UlZahGc+n2e/eTabqSiK5UE8n8/HnrWWmenYsWPL8NR1rbqux5611n7e4xn8hkLOD6wp69+v7p79/dzf2H08pc3d76dg6pvXyf+dTABHDuEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhBofHzA5zx4vW6v06hft5Chv7zOySzVPZP8XNQ5VDLpRSUlmWcvfD3nNZiqJYfoOKolBVVZPbLCnrzWa23Nl9bGbZb04pLT/u9udsp825x6e/eS+DwzObzS5rVLQpbi6KYhIHRV9ZDnoIZaUsy0ntNjNVVTX2jAO1r3s/52c1aefTUTYfvNXNU9srsfkw7OeMbFB42rZVXdfZ3/DuJaGZqWkaNU2T/eaiKJbPvlPcXNe1mqYZedF63UuV7mxyKpvLslRKSe6uuq7Vtu3Ys9bqb97LvsKT+w0vy3L5WrhtW83n8+wPYmn7fZ6maSaxuXtmc/fl5pz130PrDuK6rkdetV73hngXnqZpst+cUlr+2vOyQ68094Nhqlbv1yncz1PY2De1vTs5Crehj3/HAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEGh8fMDnPHi1b/fjWz7O/n/sbu4ylt7n6fuylulobvLIdcKKWkqqrk7pc16rCllJY3vCgKzWaz7DcXRbHcnFLKfrOZKaXt56uiKEZcM4yZLXeamcqyzP5A7t/P3eb+/Z6j/TwJDQ5P7jd61RQ3F0UxiQO5j82HrwvPUTL41uT8LNzXLy6bD8fqs9rUNk9hrzT9zesMCk/btprP59nf+KIolqfRTdOorutJbK6qSpJU17Wapsl6c/eypXsGrutadV2PvGq97oyhKAq5+/J+ztkUN6eUBr8kHByeKRzE0vZ7JlPa3MWyC/wUlGUpd18GPnf993mmsLl7j2dq4Rn6Fse03gQBcCQQHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhXDrmQmcnMDnvLgeh2TmVzf+PUNnd7c9+8upHNh2M/GweFJ6Wk2Wwmd7+sYYctpbS84VPcXBSFjh07lvVmM1NK2yfKZVlO4oDoNpuZyrK85DbkqL85paSqqlQUxcir1lt9bKwz+IynLAddNBsppewfXKvYHKMoiuwP4lVT3LzO4Jq4e9bPxNLzT/Xath1xzTD9zVO4jyU2R+kHfQqbD/ylVtM0qus6+wO5KApVVSUzU9M0ms/n2X+zyrJcvlzp7uecN5vZ8n6WpLqu1TRN9pu7+9ndVde16roee9ZaZrZ8eeXums/nappm7FlrdS8Jh5wBDwpP983K+cElLb5Z7i4zU9u22R8Q0qUvY9u2zf6AkHTJmcNUNvcPhqZpsj+Iu8B34ZnCZndXURSDwjOtF+cAjgTCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQj
PADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOHKIRcyM6WU1LbtYe+5LGYmM1t+nPvmbuPq7919xFXrrd7HU9ssSSml7DenlCb1WJaefz+vMyg8KSXNZrOsv1HSpd+soihkZpPbnPsBIemSWJZlecnvc7Qay6ls7jaamaqqUlEUI69ab/WJdJ09w9M/KKbEzCa3OfeDYSfd2cOUsPlwDTnrGfxSCwAOiuV+Wg/g6JnGuRuAI4XwAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gwWbbwQTN72szOmdmNZvbl3ucfM7OfHHMjdkZ4RmRmbzKzR8zsW2b2X2b2V2b2uq3PvdPMPrTL1z1mZs9tfV3364/M7AYze8bMrtjha/7JzO40s+vNzHtf95iZ3XXYt3Vrw6636QV6naSfkvRyd3+1u/+Du3/fAV4/Dsmg/3c6Dp6ZvV3SXZLeJulTkjYlvV7SGyR9ZsBV/Iy7/+0O1/uEpDdK+pPen71S0vdL+nNJJ7b++Cp3r83stKS/N7MNd/+bF36LRnGdpMfc/Zmxh2B/OOMZgZldKel3JN3h7h9192fcfe7uH3f3X73Mq79f0q0rf3arpE+6+1OrF3b3RyT9q6RXDblyM/tFM/uMmf3+1kuc/zSzm3uff5mZfczMvmFmXzGzt279+esl/Yakn9860/rnXa7/LjN71MwumNmXzOxnd7ncWyT9saTXbl3fu8zspq3w7nT51Lvup8zsL8zs6iG3GQeP8IzjtZKOS3rwEK77zyT9iJl9j7Q44CS9SYsgPY+Z3SDplZK+0vuzb3Yv+XbxGklflnStpPdK+oBt/0ebPizpCUkv0+LM691m9uPu/teS3i3pI+5+hbv/0C7X/aikGyVdKeldkj5kZidXL+TuH9DibPHhret7x5q9kvTLkm6R9KNb256W9P49vgaHhPCM4xpJT7p7fRnX8dBWILpfb5Ukd39c0t9JevPW5X5C0jFJn1j5+ifN7DlJD0u6W9JD3Sfc/Sp3X/dy76vufp+7N1oE7aSkl27F7ocl/bq7X3T3L2pxVrJ6BrYrd3/A3c+7e+vuH5H075JePfTr13ibpN909yfc/duS3inpjWbG2w0jIDzjeErStZf5oL9lKxDdr/t6n7tf2+F5s6QPu/t85euvlXSFpF+RdJOkah9/9393H7j7s1sfXqHFmcQ33P1C77JflfTdQ6/YzG41sy92QdXibOzafWzbzXWSHuxd779JaiS99ACuG/tEeMbxsKRva3Hqfxg+KunlZvZjkn5Ou7zMcvfG3f9A0kVJtx/A33te0tVmdqL3Z6+Q9LXur1z3xWZ2naT7JN0p6Rp3v0rSv0g6iP/27uOSbl6J9XF3/9qeX4kDR3hG4O7/K+m3Jb3fzG4xs+8ws8rMbjaz9/YumszseO/XsYHX/4ykv5T0QS1eFj2yx5e8R9KvmdnxF3J7en/v45I+K+n3tvb+oKS3SOp+hP51Sddvve+0k5doEaf/kSQz+yUtzngOwj2SfncrbjKz7zKzNxzQdWOfCM9I3P19kt4u6be0ONAe1+KZ/qHexX5B0nO9X4/2PvfxlX/Hs/pG9f1avLz40wFzPqHFm63dT6C+ZWY37vtGbW++XouznwclvaP3Y/8Htv73KTP7wuoXuvuXJL1PizPCr0v6AUn/+AJ3rPpDSR+T9GkzuyDpc1q8SY4R8B97BxCOMx4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOH+H4vna0rsUOaeAAAAAElFTkSuQmCC\n", | |
| "text/plain": [ | |
| "<Figure size 360x360 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| }, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR4AAAE2CAYAAAC3Nv2NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAQCklEQVR4nO3db6hk913H8c/3d86ZXTSxoY2UW6vJIxWpfyDdZouNxn/QPJBGKYiFBqUUFrd5UlGDim0FaylWEKwEYqmpBVsrTUip9d8DxZpNJasVtFJsMCVpa7BplE1MMnPO+fpg7pl77uzeuWd27/2e37l5v+DSu3vnzn5m7pz3nJkNW3N3AUCkNPYAAC89hAdAOMIDIBzhARCO8AAIR3gAhCs3ffHixYv8XTuAq3LLLbfYQV/bGJ4O/60PgKHMDuzNyqHhcXe5u9q2zT5AZqaiKCRpMptTSkpp+Yq3bVu1bTvyosOx+fiZmVJKq4O4aZrsH8vd5u7zTQad8bRtq/l8nv0PqyzL1Q+raRrN5/Psf1hVVamqqtXmxWKR9WYzU1mWms1mkqS6rlXXdfabq6pSSknursViobqux561kZlpNpupLEu5++p+zllKSbPZbPXkv8ngl1pTOHvozs7MbPV57pv7Me/u55z179P+2XDO1h8HU3hcSNp3P0/p+BuCv9UCEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQLhyyIXMTGVZyt2Pe881SSnJzFafF0Ux8qLD9TemlFSWg34ko+o2m9lkNqe09xw7hceFtLfZzFQUxeqxnavu8TDEoEdMURRKKWUfHjNb/XCmvDl3/QOgLMtJHMjd5u5JNPfN/ceFmamqqkk9lg8z+KlqmyvNxdQ2T22vxOYoU9y8yaDwuLuapsm+uP2XV23bqm3byW1ummbkRYfrn5lNYXP3EqDb3DSN2rYdedVmU9089CXhoPC0bav5fJ79QdwdEGY2mc1lWa42N02j+Xw+9qRDVVWl2WwmSarrWovFYuRFh5vNZquX3nVdq67rsSdtZGarzW3barFYZB/47nE85GXs4DOe7iN37i4zm8zm/r4p7JX2Nk/lPpamfT93n+e+eZuN+b+TCeDEITwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4RmZmY09AQhXDrlQSklVValt2+Pec02KolgdyCkllWUpdx951cHMTCml1eaiKFRVVfabi6K47POpbJa07/NcdY+N7vOyLLN/kkoprTYfZqvwTEH/IB56J4yp/2Ca0uZz586NPeOq3XvvvSrLQQ/9UXWPjS48U9p8mEG35OzZF9U0O9c0CCdNqbPfeVGnVY89ZCsvqNSZM2fGnrGVnZ0dPfTQQ2PPOFKDwtM0O3r44aeyPp2Wlmdm3cuttm3VNE32m4uiWJ36N02jtm0nsfn8+VqnL9V630ceUdM0Y0/aqHvZcv78eV1/6aIuXLiQ/dsG3cvDlJLOnDmjuq4ntfkwg8/dFotF9je8LMvVeyZN02g+n2d/EFdVtdrctm32m9dPpZum0WKxyH7zbDZb/brbnDMz06lTp1YHcV3Xquu8zy63eY9n8BsKOT+wpqx/v7p79vfz+sapbp6CqW/eJP93MgGcOIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCE
B0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEGxweMzvOHS9Z6/frFO7nKWzsM7N9m6eyf4qbhyoHX7As5e7HueWaFUWx+gEVRaGqqia3WVLWm81stVPa25/75pT2nmP7+3N1pc25x2d98yaDwzObza560BhSSpPbXBTFJA6KvrIc/BDKRlmWk9tdVdXYE47UVvd+zs9q0pVPR9l89NY3T22vxObjsM0Z2eDwzOfz7G94SkllWcrM1DSNmqbJfnNRFKtn3ylt7tR1raZpRlxzuPWXh1PZXJbl6qXLYrFQ27Yjr9psffMmg8NT13X2N7wsy9Vr4bZttVgssj+Ipb3X703TTGJz/5mt25yz9Wfiuq5V1/VIa4bp3hDvDuKmabLfnFJafRx62aFXmvvBMFXr9+sU7ucpbOyb2t4rOQm3oY//jgdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCDQ6PmR3njpes/v1qZtnfz+sbp7o5d1PcLA3fWQ69wqqq5O5XPShCSml1w4ui0Gw2y35zURSrzSml7DebmVLae74qimLENcOY2b6dZVlmfyCv389lWe77dY62eRLaKjxTklLK/ge1riiKSRzIfWyOUZaDD9VJGHxrcn4W7usXl83Ho7/3zGvPjLhkS9dLt1+axn0sTe9xIR3DS635fJ79jS+KYnUa3TSN6rqexObubLKuazVNk/Xm9ZctFx65oLquR1x0ODNTWZY6f+m8pIur+zln3ebuvp7C5pTS4JeEg8MzhYNY2nvPpG3byWzuYtm2rRaLxdhzttIFPnf9Z+IpbO7e45laeIa+xTGtN0EAnAiEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMKVQy9oZse548h0O81sEpv7G9l8PNY3svl4bLNxcHhms5nc/apHRUgprW54Smlym4ui0KlTp7LebGZKae9EuSzLSRwQ65v7v87R+uaqqlQUxYiLDre+eZPB4SnLwRfNQkop+wfXOjbHKIoi+4N43RQ3bzK4Ju6e9TOxdPmpXtu2I64Zpr95CvextP+lFpuPTz/oU9h8LC+15vN59gdyURSqqkpmpqZptFgssv9hlWW5ernSNI3qus56s5nte+at61pN02S/uX/GXte16roecdHhzGzfy6vFYqGmaUZetVlKSVVVDToDHhye3A8IafnDcneZmdq2zf6AkPYfFG3bZn9ASLrsrHIKm/sHQ9M02R/EXeC78Exhs7urKIpB4ZnWi3MAJwLhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHDl0AumlNS27XFuuWZmJjNbfZ775m7j+q/dfcRVm/Xv4+7XU9ucUsp+c0rpss05P5aly+/nTQaHZzabZf2Dkvb/sIqikJlNbnPuB4SkfbG8++67R1yyne5+Lsty
323I0fqTUlVVKopixEWHW9+8yaHhMTMVxdd09uzONQ/DSVLq7HdL3/LsP409ZCsvJOnWW28de8ZWdnZ2VmdpUzDkrGfQGc/nPnda0jPXugcnygfHHoAJs9xP6wGcPNM4dwNwohAeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA8my5Y+bGbPmNk/mtltZvbF3tcfN7MfH3MjrozwjMjM3mJmj5rZs2b2NTP7jJm9Yfdr7zazjx7wfY+b2fO739d9/L6ZnTWz58zsuit8zz+b2TvM7GYz8973PW5m9xz3bd3dcOBtukpvkPQTkl7t7q9z97939+86wuvHMRn87/HgaJnZOyXdI+mcpL+UNJf0RklvkvTZAVfxk+7+N1e43iclvVnSH/V+7zWSvkfSn0i6fve3b3D32sxeK+nvzOyiu//11d+iUdwk6XF3f27sIdgOZzwjMLOXSfpNSefd/ZPu/py7L9z9U+7+S9d49fdLumvt9+6S9Ofu/vT6hd39UUn/JukHhly5mf2cmX3WzH5n9yXOf5rZHb2vv8rMHjKzb5jZl8zs7bu//0ZJvyrpZ3bPtP7lgOu/x8weM7NLZvYFM/upAy73Nkl/KOn1u9f3HjO7fTe8V7p86l3302b2p2b28iG3GUeP8Izj9ZJOS3rgGK77jyX9kJl9u7Q84CS9RcsgXcbMzkp6jaQv9X7vf7qXfAe4VdIXJd0o6f2SPmR7//rTxyQ9KelVWp55vdfMftTd/0LSeyV93N2vc/fvP+C6H5N0m6SXSXqPpI+a2WX/Cp27f0jLs8ULu9f3rg17JeluSXdK+uHdbc+If1RoNIRnHK+Q9HV3r6/hOh7cDUT38XZJcvcnJP2tpLfuXu7HJJ2S9Om17/+6mT0v6YKkP5D0YPcFd7/B3Te93Puyu9/n7o2WQduR9Mrd2P2gpF9x9xfc/fNanpWsn4EdyN0/4e5fdffW3T8u6T8kvW7o929wTtKvufuT7v6ipHdLerOZ8XbDCAjPOJ6WdOM1Pujv3A1E93Ff72v3ay88b5X0MXdfrH3/jZKuk/SLkm6XVG3xZ/9X94m7/9/up9dpeSbxDXe/1LvslyV929ArNrO7zOzzXVC1PBu7cYttB7lJ0gO96/13SY2kVx7BdWNLhGccFyS9qOWp/3H4pKRXm9mPSPppHfAyy90bd/9dSS9I+oUj+HO/KunlZnZ97/e+Q9JXuj9y0zeb2U2S7pP0DkmvcPcbJP2rpGH/1wWbPSHpjrVYn3b3rxz6nThyhGcE7v6/kn5D0gfN7E4z+yYzq8zsDjN7f++iycxO9z5ODbz+5yT9maQPa/my6NFDvuV9kn7ZzE5fze3p/blPSHpY0m/v7v0+SW+T1P0V+lOSbt593+lKvlnLOP23JJnZz2t5xnMU7pX0W7txk5l9q5m96YiuG1siPCNx9w9IeqekX9fyQHtCy2f6B3sX+1lJz/c+Hut97VNr/x3P+hvV92v58uIjA+Z8Wss3W7u/gXrWzG7b+kbtbb5Zy7OfByS9q/fX/p/Y/d+nzeyy/3sKd/+CpA9oeUb4lKTvlfQPV7lj3e9JekjSX5nZJUmPaPkmOUbAP/YOIBxnPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwv0/QQEZOjirj98AAAAASUVORK5CYII=\n", | |
| "text/plain": [ | |
| "<Figure size 360x360 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ |
| "# Display results for example index vidx; show_results must already be defined in the kernel.\n", |
| "vidx = 2\n", |
| "show_results(vidx=vidx)" |
| ] |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "ed338a97", | |
| "metadata": {}, | |
| "source": [ | |
| "## Predictions" |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "15bc44f5", | |
| "metadata": {}, | |
| "source": [ | |
| "### PCFGs" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 251, | |
| "id": "52a4aa3a", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ |
| "# Unpack the grammar parameters; pcfgs must be a 3-item sequence (binary-rule, start/root, pre-terminal scores).\n", |
| "# NOTE(review): later cells call .exp() on these, so they appear to be log-probabilities - confirm.\n", |
| "rule_prob, root_prob, term_prob = pcfgs" |
| ] |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "1c938386", | |
| "metadata": {}, | |
| "source": [ | |
| "#### Binary rules" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 252, | |
| "id": "18957cc9", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "torch.Size([4, 6, 2, 9, 9]) torch.Size([4, 6])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ |
| "# Dim 2 of rule_prob indexes the composition type (presumably 0 = left-right, 1 = above-below).\n", |
| "lr = rule_prob.select(2, 0)  # (B, NT, NT_T, NT_T)\n", |
| "ab = rule_prob.select(2, 1)  # (B, NT, NT_T, NT_T)\n", |
| "\n", |
| "# Log of the total mass each parent puts on each type, marginalizing out both children: (B, NT)\n", |
| "lr_lp = torch.logsumexp(lr, dim=(-1, -2))\n", |
| "ab_lp = torch.logsumexp(ab, dim=(-1, -2))\n", |
| "print(rule_prob.shape, lr_lp.shape)" |
| ] |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 253, | |
| "id": "fabb1c9b", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "tensor([0.3788, 0.3788, 0.3788, 0.3788], device='cuda:0') uniform: 4.394449154672439\n" | |
| ] | |
| } | |
| ], | |
| "source": [ |
| "# Renormalize each (sample, parent, type) slot over all flattened child pairs and take its entropy.\n", |
| "part_prob = torch.log_softmax(rule_prob.flatten(start_dim=-2), dim=-1)\n", |
| "part_ent = (-part_prob.exp() * part_prob).sum(dim=-1)\n", |
| "# Reference value: entropy of a uniform distribution over the same number of child pairs.\n", |
| "print(part_ent.mean((-1, -2)), \"uniform:\", np.log(part_prob.shape[-1]))" |
| ] |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 254, | |
| "id": "4761b8b9", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ |
| "if ipcfg_la:\n", |
| "    # Sanity check: under ipcfg_la, rule_prob is presumably already normalized over child pairs,\n", |
| "    # so re-normalizing it (part_prob) should be a no-op. Compare with a tolerance: log_softmax\n", |
| "    # recomputes the normalizer, so exact == on floats can report False spuriously.\n", |
| "    print(torch.allclose(rule_prob.flatten(-2), part_prob, atol=1e-6))" |
| ] |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 255, | |
| "id": "7c3b2b93", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "tensor([[1.0000, 1.0000],\n", | |
| " [1.0000, 1.0000],\n", | |
| " [1.0000, 1.0000],\n", | |
| " [1.0000, 1.0000],\n", | |
| " [1.0000, 1.0000],\n", | |
| " [1.0000, 1.0000]], device='cuda:0')\n", | |
| "tensor([[1.0517e+00, 1.0517e+00],\n", | |
| " [5.1468e-01, 5.1468e-01],\n", | |
| " [2.3368e-06, 2.3368e-06],\n", | |
| " [7.0624e-01, 7.0624e-01],\n", | |
| " [2.9700e-06, 2.9700e-06],\n", | |
| " [5.6108e-07, 5.6108e-07]], device='cuda:0')\n", | |
| "tensor(0.3788, device='cuda:0') tensor(nan, device='cuda:0')\n" | |
| ] | |
| } | |
| ], | |
| "source": [ |
| "lr_ab = torch.stack((lr_lp, ab_lp), -1)  # (B, NT, 2): per-parent log-mass of each composition type\n", |
| "print(lr_ab[0].exp())\n", |
| "print(part_ent[0])\n", |
| "\n", |
| "# Mean child-pair entropy for (parent, type) slots whose type mass is >= fulcrum vs. < fulcrum.\n", |
| "# A side can be empty (e.g. when each type is normalized on its own every mass is 1), and the\n", |
| "# mean over an empty selection is NaN, so report None for an empty side instead.\n", |
| "fulcrum = 0.95\n", |
| "type_mass = lr_ab.exp()\n", |
| "above_mask = type_mass >= fulcrum\n", |
| "below_mask = type_mass < fulcrum\n", |
| "above = part_ent[above_mask].mean() if above_mask.any() else None\n", |
| "below = part_ent[below_mask].mean() if below_mask.any() else None\n", |
| "print(above, below)" |
| ] |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 256, | |
| "id": "acb68a7d", | |
| "metadata": { | |
| "scrolled": true | |
| }, | |
| "outputs": [], | |
| "source": [ |
| "# NOTE(review): disabled exploratory check - flattened child-pair probabilities for one (sample, parent, type) slot.\n", |
| "# rule_prob[0, 1, 0].flatten(-2).exp() #.sum()" |
| ] |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 257, | |
| "id": "b37cb133", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ |
| "# NOTE(review): disabled entmax experiment; entmax_bisect is not imported in the top import cell\n", |
| "# (presumably it comes from the entmax package), so it must be imported before uncommenting.\n", |
| "# x_entmax = entmax_bisect(rule_prob[1, 1, 0].flatten(-2), 1.05)\n", |
| "# x_entmax[x_entmax.isnan()] = 0\n", |
| "# x_entmax" |
| ] |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "10732388", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 258, | |
| "id": "03d3d9ab", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "tensor([0.0094, 0.0094, 0.0094, 0.0094], device='cuda:0')\n", | |
| "tensor([0.6931, 0.6931, 0.6931, 0.6931], device='cuda:0')\n", | |
| "tensor([2.9802e-08, 2.9802e-08, 2.9802e-08, 2.9802e-08], device='cuda:0')\n" | |
| ] | |
| } | |
| ], | |
| "source": [ |
| "if ipcfg_la:\n", |
| "    # Entropy over child pairs for each parent, averaged over parents: (B,)\n", |
| "    lr_ent = -(lr.exp() * lr).sum((-1, -2)).mean(-1)  # (B,)\n", |
| "    ab_ent = -(ab.exp() * ab).sum((-1, -2)).mean(-1)  # (B,)\n", |
| "    print(lr_ent, ab_ent)\n", |
| "else:\n", |
| "    # Entropy over the type axis (dim 2) of the raw scores, averaged over parent and child dims.\n", |
| "    ent = -(rule_prob.exp() * rule_prob).sum(2).mean((-1, -2, -3))\n", |
| "    print(ent)\n", |
| "\n", |
| "    # Same quantity after renormalizing over the type axis.\n", |
| "    rule_lp = torch.log_softmax(rule_prob, dim=2)\n", |
| "    ent = -(rule_lp.exp() * rule_lp).sum(2).mean((-1, -2, -3))\n", |
| "    print(ent)\n", |
| "\n", |
| "    # Entropy of each parent's mass split between the two types, averaged over parents.\n", |
| "    rule_ent = (-(lr_ab.exp() * lr_ab).sum(-1)).mean(-1)\n", |
| "    print(rule_ent)" |
| ] |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 259, | |
| "id": "2f8a6460", | |
| "metadata": { | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "tensor([1., 1., 1., 1.], device='cuda:0')\n", | |
| "tensor([1., 1., 1., 1.], device='cuda:0')\n", | |
| "tensor([-0., -0., -0., -0.], device='cuda:0')\n" | |
| ] | |
| } | |
| ], | |
| "source": [ |
| "if not ipcfg_la:\n", |
| "    # Average total mass per type across parents: (B,)\n", |
| "    print(torch.exp(lr_lp).mean(dim=-1))\n", |
| "    print(torch.exp(ab_lp).mean(dim=-1))\n", |
| "\n", |
| "    # Per sample: log(mean over parents of each type's mass), i.e. the preference for lr vs. ab.\n", |
| "    # NOTE(review): the entropy below is only meaningful if lr and ab share one normalizer;\n", |
| "    # when each type is normalized separately both preferences are log(1) = 0.\n", |
| "    lr_prefer = torch.logsumexp(lr_lp - np.log(lr_lp.shape[-1]), dim=-1)\n", |
| "    ab_prefer = torch.logsumexp(ab_lp - np.log(ab_lp.shape[-1]), dim=-1)\n", |
| "    lr_ab_prefer = torch.stack([lr_prefer, ab_prefer], -1)\n", |
| "    ent = -(lr_ab_prefer.exp() * lr_ab_prefer).sum(-1)\n", |
| "    print(ent)" |
| ] |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "d9c8e1d1", | |
| "metadata": {}, | |
| "source": [ | |
| "#### Start rules" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 260, | |
| "id": "31c9f6e1", | |
| "metadata": { | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[[0.08 0.07 0.26 0.07 0.26 0.26]\n", | |
| " [0.08 0.07 0.26 0.07 0.26 0.26]\n", | |
| " [0.08 0.07 0.26 0.07 0.26 0.26]\n", | |
| " [0.08 0.07 0.26 0.07 0.26 0.26]]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ |
| "with np.printoptions(precision=2, suppress=True):\n", |
| "    if ipcfg_la:\n", |
| "        # Regroup each sample's root scores into 2 groups along a new middle axis.\n", |
| "        B = root_prob.shape[0]\n", |
| "        root_p = root_prob.view((B, 2, -1))\n", |
| "        print(np.exp(root_p.cpu().numpy()), root_p.shape)\n", |
| "    else:\n", |
| "        print(np.exp(root_prob.cpu().numpy()))" |
| ] |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "38d61ee7", | |
| "metadata": {}, | |
| "source": [ | |
| "#### Pre-terminal rules" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 261, | |
| "id": "73b147f9", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.66 0. 0. ] [1]\n", | |
| "[0.66 0. 0. ] [1]\n", | |
| "[0.66 0. 0. ] [1]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0. 1. 0.] [2]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0. 1. 0.] [2]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0. 1. 0.] [2]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0.34 0. 1. ] [0]\n", | |
| "[0. 1. 0.] [2]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ |
| "v_seq = batch_dict[\"image\"]\n", |
| "\n", |
| "isample = 0\n", |
| "\n", |
| "# Observed values for one sample, as a column so each row prints next to its distribution.\n", |
| "x = v_seq[isample].unsqueeze(-1).cpu().numpy()\n", |
| "if ipcfg_la:\n", |
| "    # term_prob is (B, L, T); regroup the last axis into 2 groups per position.\n", |
| "    B, L, T = term_prob.shape\n", |
| "    x_pv = term_prob[isample].view((L, 2, -1)).cpu().numpy()\n", |
| "else:\n", |
| "    x_pv = term_prob[isample].cpu().numpy()\n", |
| "\n", |
| "with np.printoptions(precision=2, suppress=True):\n", |
| "    for pv, xx in zip(x_pv, x):\n", |
| "        print(np.exp(pv), xx)" |
| ] |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 262, | |
| "id": "6194323b", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ |
| "# NOTE(review): disabled; would display the exponentiated pre-terminal scores (x_pv) from the previous cell.\n", |
| "# np.exp(x_pv)" |
| ] |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "e1680721", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "3b24330b", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "55465616", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "a79a0cf7", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.8.12" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment