@zhaoyanpeng
Created November 17, 2022 22:45
{
"cells": [
{
"cell_type": "markdown",
"id": "897059d7",
"metadata": {},
"source": [
"## Env"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c68042e9",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import glob\n",
"import copy\n",
"import subprocess\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"from nltk import word_tokenize\n",
"from collections import Counter, defaultdict\n",
"\n",
"import re\n",
"import io\n",
"import os, sys\n",
"import requests\n",
"import PIL\n",
"from PIL import Image, ImageDraw, ImageFont\n",
"import numpy as np\n",
"\n",
"import random\n",
"from os import path\n",
"from datetime import datetime\n",
"from dataclasses import dataclass\n",
"import nltk\n",
"\n",
"import yaml\n",
"import torch\n",
"from omegaconf import OmegaConf\n",
"from hydra import initialize, initialize_config_module, initialize_config_dir, compose\n",
"\n",
"import pandas as pd\n",
"from functools import partial\n",
"from datasets import load_dataset\n",
"from braceexpand import braceexpand"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8add443d",
"metadata": {},
"outputs": [],
"source": [
"vreason_path = \"/mnt/yann/code/vreason\"\n",
"paths = [vreason_path]\n",
"for ppath in paths:\n",
" if ppath not in sys.path:\n",
" sys.path.insert(0, ppath)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "265999ce",
"metadata": {},
"outputs": [],
"source": [
"home = \"/mnt/zhaoyanpeng\"\n",
"repo_root = f\"{home}/code/clevr-ed/caption\"\n",
"dalle_droot = f\"{home}/data/dallemini\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d29cfd0b",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "eec4580b",
"metadata": {},
"outputs": [],
"source": [
"# %reload_ext autoreload"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a865a7b8",
"metadata": {},
"outputs": [],
"source": [
"import vreason\n",
"from vreason.module import PretrainedVQGAN\n",
"from vreason.data import build_clevr_image_text_data\n",
"from vreason.monitor import DalleMonitor\n",
"from vreason.module import ENCODER_HEADS_REGISTRY\n",
"from vreason.module import DECODER_HEADS_REGISTRY\n",
"from vreason.model import build_main_model"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "4a40a022",
"metadata": {},
"outputs": [],
"source": [
"from matplotlib import pyplot as plt\n",
"from matplotlib import image as ireader\n",
"import matplotlib.patches as patches\n",
"from matplotlib.patches import Rectangle\n",
"import matplotlib.gridspec as gridspec\n",
"plt.rcParams[\"savefig.bbox\"] = 'tight'"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "3811e008",
"metadata": {},
"outputs": [],
"source": [
"# %reload_ext autoreload\n",
"from ipynb.fs.full.ipcfg_analysis_parse_functions import (\n",
" convex_hull,\n",
" yxyx_to_xyxy,\n",
" bias_by_width,\n",
" estimate_overlap,\n",
" show_image_bboxes,\n",
" visualize_2d_parse,\n",
" visualize_partial_parse_toy,\n",
" visualize_partial_parse_clevr,\n",
" obj_name_and_box_from_annotations,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4f4326e8",
"metadata": {},
"outputs": [],
"source": [
"# from vreason.module.decoder.ipcfg_head import IPCFG2DDecHead, IPCFG1DDecHead\n",
"# DECODER_HEADS_REGISTRY._obj_map.pop(\"IPCFG2DDecHead\", \"\")\n",
"# DECODER_HEADS_REGISTRY.register(IPCFG2DDecHead)\n",
"# DECODER_HEADS_REGISTRY._obj_map.pop(\"IPCFG1DDecHead\")\n",
"# DECODER_HEADS_REGISTRY.register(IPCFG1DDecHead)"
]
},
{
"cell_type": "markdown",
"id": "2d9bdf03",
"metadata": {},
"source": [
"## Conf for loading data"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "9e5e66a4",
"metadata": {
"code_folding": [
7,
16,
49
]
},
"outputs": [],
"source": [
"alias_root=\"/mnt/yann/model/ipcfg\"\n",
"model_root=\"/mnt/yann/model/ipcfg\"\n",
"\n",
"# pre-trained VQGAN\n",
"vq_root = \"/mnt/yann/model/vqgan_tuned/logs\"\n",
"\n",
"# 256 x 16384\n",
"class VQConf:\n",
" vq_time = \"2022-09-10T00-25-13\"\n",
" vq_name = \"clevr_vqgan_128e_16384\"\n",
" vq_file = \"last.ckpt\"\n",
" vocab_size = 16384\n",
" max_vis_len = 256\n",
" gumbel = False\n",
"\n",
"# 256 x 1024\n",
"class VQConf:\n",
" vq_root = vq_root\n",
" vq_time = \"2022-09-10T05-13-14\"\n",
" vq_name = \"clevr_vqgan_128e_1024\"\n",
" vq_file = \"last.ckpt\"\n",
" vocab_size = 1024\n",
" max_vis_len = 256\n",
" gumbel = False\n",
" \n",
"vq = VQConf()\n",
"\n",
"text_fold = \"CAPTION_4.0_25_captions\"\n",
"data_root = f\"/mnt/yann/data/CLEVR_text/{text_fold}\"\n",
"more_root = f\"/mnt/yann/data/CLEVR_v1.0_320x320_bbox\"\n",
"\n",
"model_sign = vq.vq_file.split(\".\", 1)[0]\n",
"data_root = f\"/mnt/yann/data/dallemini/{vq.vq_time}_{vq.vq_name}/{model_sign}/{text_fold}\"\n",
"\n",
"data_name = \"train/\\{001..007\\}.parquet\"\n",
"eval_name = \"val/\\{001..002\\}.parquet\"\n",
"\n",
"epoch = 1\n",
"bsize = 2\n",
"\n",
"min_obj_num = 1\n",
"max_obj_num = 1\n",
"\n",
"model_name = \"ipcfg_test\"\n",
"model_file = \"00002560.pth\"\n",
"\n",
"alias_name = \"ipcfg_eval_toy\"\n",
"alias_odir = \"toy_5x5\"\n",
"\n",
"class DataConf:\n",
" data_root = data_root\n",
" more_root = more_root\n",
" data_name = data_name \n",
" eval_name = eval_name\n",
" txt_vocab_name = \"txt_vocab_is_null\"\n",
" min_obj_num = min_obj_num\n",
" max_obj_num = max_obj_num\n",
" min_train_vid = 0\n",
" max_train_vid=9999\n",
" min_eval_vid=0\n",
" max_eval_vid=999\n",
" train_samples = None \n",
" eval_samples = \"[0,128]\" \n",
" test_samples = None\n",
" batch_size = bsize\n",
" vis_vocab_size=3\n",
" obj_fake_hw = \"[[3,1],[1,4]]\"\n",
"\n",
"dt = DataConf()\n",
"\n",
"verbose = True\n",
"training = False\n",
"worker = \"IPCFG\"\n",
"Monitor = \"DalleMonitor\""
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "7a64ac25",
"metadata": {
"code_folding": [
0,
11,
32
]
},
"outputs": [],
"source": [
"base_overrides = [\n",
" \"+model/vq=default\",\n",
" \"+model/embedder=dummy\",\n",
" \"+model/encoder=dummy\",\n",
" \"+model/decoder=ipcfg\",\n",
" \"+model/loss=ipcfg\",\n",
"# \"+optimizer=dummy\",\n",
" \"+data=clevr_dalle\",\n",
" \"+running=default\"\n",
"]\n",
"\n",
"special_overrides = [\n",
" f\"model.vq.gumbel={vq.gumbel}\",\n",
" f\"model.vq.model_root={vq.vq_root}\",\n",
" f\"model.vq.model_name={vq.vq_name}\",\n",
" f\"model.vq.model_file={vq.vq_file}\",\n",
" f\"model.vq.model_time={vq.vq_time}\",\n",
" f\"model.vq.vocab_size={vq.vocab_size}\",\n",
" f\"model.vq.max_vis_len={vq.max_vis_len}\",\n",
" f\"model.vq.load=false\",\n",
" \n",
" f\"+model.encoder.z_dim=0\",\n",
" \n",
" f\"model.decoder.NT=8\",\n",
" f\"model.decoder.T=3\",\n",
" f\"model.decoder.s_dim=256\",\n",
" \n",
" \n",
" f\"running.epochs={epoch}\",\n",
" f\"running.batch_size={bsize}\",\n",
"]\n",
"\n",
"extra_overrides = [\n",
" f\"port=9990\",\n",
" f\"num_gpus=1\",\n",
" f\"verbose={verbose}\",\n",
" f\"eval={not training}\",\n",
" f\"mode=dp\",\n",
" f\"num_proc=0\",\n",
" f\"seed=1213\",\n",
" f\"rank=0\",\n",
" f\"autocast=false\",\n",
" f\"worker={worker}\",\n",
" f\"alias_odir={alias_odir}\",\n",
" f\"alias_root={alias_root}\",\n",
" f\"alias_name={alias_name}\",\n",
" f\"model_root={model_root}\",\n",
" f\"model_name={model_name}\",\n",
" f\"model_file={model_file}\",\n",
" f\"check_unused_param=false\",\n",
" f\"data.data_root={dt.data_root}\",\n",
" f\"data.more_root={dt.more_root}\",\n",
" f\"data.data_name={dt.data_name}\",\n",
" f\"data.eval_name={dt.eval_name}\",\n",
" f\"data.train_samples={dt.train_samples}\",\n",
" f\"data.eval_samples={dt.eval_samples}\",\n",
" f\"data.test_samples={dt.test_samples}\",\n",
" f\"data.batch_size={dt.batch_size}\",\n",
" \n",
" f\"data.txt_vocab_name={dt.txt_vocab_name}\",\n",
" f\"data.min_obj_num={dt.min_obj_num}\",\n",
" f\"data.max_obj_num={dt.max_obj_num}\",\n",
"# f\"data.={dt.}\",\n",
"# f\"data.={dt.}\",\n",
"# f\"data.={dt.}\",\n",
"# f\"data.={dt.}\",\n",
"# f\"data.={dt.}\",\n",
" \n",
" f\"data.vis_vocab_size={dt.vis_vocab_size}\",\n",
" f\"+data.vis_fake_hw=5\",\n",
" f\"+data.obj_fake_hw={dt.obj_fake_hw}\",\n",
" f\"+data.obj_delt_hw=[]\",\n",
" \n",
" f\"+data.require_txt=false\",\n",
" f\"+data.use_preencoded_pandas=true\",\n",
" f\"+data.txt_special_token={{bos:\\\"<TXT|BOS>\\\",pad:\\\"<PAD>\\\",unk:\\\"<UNK>\\\"}}\",\n",
" f\"+data.vis_special_token={{bos:\\\"<VIS|BOS>\\\",unk:\\\"<UNK>\\\"}}\",\n",
"]\n",
"overrides = base_overrides + extra_overrides + special_overrides"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "7fcc5316",
"metadata": {
"code_folding": [
1
]
},
"outputs": [],
"source": [
"abs_config_dir=os.path.abspath(f\"{vreason_path}/configs\")\n",
"with initialize_config_dir(config_dir=abs_config_dir):\n",
" cfg = compose(config_name=\"default\", overrides=overrides)\n",
"# print(OmegaConf.to_yaml(cfg))\n",
"mcfg = cfg.model\n",
"echo = print"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "97de68a4",
"metadata": {},
"outputs": [],
"source": [
"# print(OmegaConf.to_yaml(cfg))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "50713902",
"metadata": {
"code_folding": [
0
]
},
"outputs": [],
"source": [
"def bias_by_width(argmax_list):\n",
" row = col = other = total = nsample = 0\n",
" for pred in argmax_list:\n",
" pred = yxyx_to_xyxy(pred)\n",
" nsample += 1\n",
" total += len(pred)\n",
" for b in pred:\n",
" x1, y1, x2, y2 = b\n",
"# y1, x1, y2, x2 = b\n",
" if x2 - x1 == 1:\n",
" col += 1\n",
" if y2 - y1 == 1:\n",
" row += 1\n",
" if x2 - x1 != 1 and y2 - y1 != 1:\n",
" other += 1\n",
" ratio_row, ratio_col, ratio_other = [x / total for x in [row, col, other]]\n",
" print(\n",
" f\"Biased by \" +\n",
" f\"row {ratio_row:.3f} ({row / nsample:.3f}) \" +\n",
" f\"col {ratio_col:.3f} ({col / nsample:.3f}) \" +\n",
" f\"other {ratio_other:.3f} ({other / nsample:.3f}) # sample {nsample}\"\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "55398be7",
"metadata": {},
"source": [
"## Infer"
]
},
{
"cell_type": "markdown",
"id": "ac7bba21",
"metadata": {},
"source": [
"### Load data"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "f0f1d19b",
"metadata": {},
"outputs": [],
"source": [
"cfg.data.batch_size = 4\n",
"cfg.data.obj_fake_hw = [[4,1],[1,3]]\n",
"convas_size = (16 * 5 + 1,) * 2\n",
"\n",
"linear = True\n",
"if linear:\n",
" grid_size = [6,6]\n",
" cfg.data.vis_fake_hw=grid_size\n",
" cfg.data.obj_fake_hw=[[4,1],[1,3]] #[[4,1],[1,3]] #\n",
" cfg.data.obj_delt_hw=[]\n",
" cfg.data.eval_samples=1024\n",
" cfg.model.decoder.grid_size=grid_size\n",
" \n",
" convas_size = (16 * grid_size[0] + 1, 16 * grid_size[1] + 1)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "500a7b7e",
"metadata": {
"code_folding": [
0
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Err: [Errno 2] No such file or directory: '/mnt/yann/data/dallemini/2022-09-10T05-13-14_clevr_vqgan_128e_1024/last/CAPTION_4.0_25_captions/vis_vocab_is_null'\n",
"Err: [Errno 2] No such file or directory: '/mnt/yann/data/dallemini/2022-09-10T05-13-14_clevr_vqgan_128e_1024/last/CAPTION_4.0_25_captions/txt_vocab_is_null'\n",
"Load 500 (2000) batches (main).\n"
]
}
],
"source": [
"dataloader, evalloader, testloader, encoder_vocab, decoder_vocab, cate_vocab = \\\n",
" eval(f\"build_{cfg.data.name.lower()}_image_text_data\")(\n",
" cfg.data, not cfg.eval, echo, ddp_mode=(cfg.mode == \"ddp\")\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "20e1c7ce",
"metadata": {},
"outputs": [],
"source": [
"# for ix, x in enumerate(dataloader):\n",
"# # print(x)\n",
"# if ix == 1:\n",
"# break"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "0e46f7e5",
"metadata": {
"code_folding": [
0
]
},
"outputs": [],
"source": [
"def check_data(dataloader, k=3):\n",
" image = list()\n",
" for ix, x in enumerate(dataloader):\n",
"# print(x[\"image\"])\n",
"# if ix == 1:\n",
"# break\n",
" image.append(x[\"image\"])\n",
" \n",
" image = np.concatenate(image, axis=0)\n",
" return image, np.eye(k)[image]\n",
"image_all, image_onehot = check_data(dataloader)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "0297a25a",
"metadata": {},
"outputs": [],
"source": [
"x1 = image_onehot.mean(0)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "3ac3c67d",
"metadata": {
"code_folding": []
},
"outputs": [],
"source": [
"# x1"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "b90f6569",
"metadata": {
"code_folding": []
},
"outputs": [],
"source": [
"# x2 = image_onehot.mean(0)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "83c8ed6c",
"metadata": {},
"outputs": [],
"source": [
"# x2"
]
},
{
"cell_type": "markdown",
"id": "bc6355a6",
"metadata": {},
"source": [
"### Parse"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "7d7834ba",
"metadata": {
"code_folding": [
0,
13
]
},
"outputs": [],
"source": [
"def make_batch(sample, device):\n",
" image = sample.pop(\"image\") #[\"image\"]\n",
" image = torch.from_numpy(image).to(device)\n",
" if len(image.shape) == 4: # can be pre-encoded image token sequences\n",
" image = image.permute(0, 3, 1, 2)\n",
"\n",
" text = sample.pop(\"text\") #[\"text\"]\n",
" text = torch.from_numpy(text).to(device)\n",
"\n",
" text_mask = text == encoder_vocab.PAD_IDX\n",
"\n",
" return {\"image\": image, \"text\": text, \"text_mask\": text_mask, \"vfile\": sample.pop(\"vfile\")} #[\"vfile\"]}\n",
"\n",
"def infer_batch(model, device, kbatch=4):\n",
" ibatch, kbatch = 0, kbatch\n",
" for ibatch, batch in enumerate(dataloader):\n",
" if ibatch != kbatch:\n",
" continue\n",
" # print(batch, batch[\"image\"].dtype)\n",
" batch_dict = make_batch(batch, device)\n",
" batch_dict.update(batch)\n",
" # print(batch_dict, batch_dict[\"image\"].dtype)\n",
" break\n",
" \n",
" v_seq = batch_dict[\"image\"]\n",
"\n",
" mean = lvar = None\n",
" kwargs = {\n",
" \"infer\": True, \"auto_infer\": False, \"exclude_trivial\": True,\n",
" \"require_marginal\": False, \"marginal_as_dict\": True, \"mbr\": False,\n",
" } \n",
" \n",
" with torch.no_grad():\n",
" ll, kl, targets, argmax, marginal, dec_extra = model.decoder_head(\n",
" None, None, v_seq=v_seq, mean=mean, lvar=lvar, **kwargs\n",
" )\n",
" \n",
" return batch_dict, argmax, model.decoder_head.pcfgs"
]
},
{
"cell_type": "markdown",
"id": "75b96a96",
"metadata": {
"heading_collapsed": true
},
"source": [
"### Load model"
]
},
{
"cell_type": "markdown",
"id": "256c0628",
"metadata": {},
"source": [
"### Model A"
]
},
{
"cell_type": "code",
"execution_count": 238,
"id": "d57a4172",
"metadata": {},
"outputs": [],
"source": [
"# cfg.model_name = \"ipcfg_test_662341_12_s8765_1585_50_85_3_1\"\n",
"# cfg.model_name = \"ipcfg_test_661341_12_s8765_1585_50_85_3_1\"\n",
"\n",
"\n",
"cfg.model_name = \"ipcfg_test_664113_12_s8765_1585_50_85_5_1\" # best\n",
"# cfg.model_name = \"ipcfg_test_664123_12_s8765_1585_50_85_5_1\" # best\n",
"\n",
"# cfg.model_name = \"ipcfg_test_661431_12_s8765_1585_50_85_5_1\"\n",
"# cfg.model_name = \"ipcfg_test_661432_12_s8765_1585_50_85_5_1\"\n",
"\n",
"# cfg.model_name = \"ipcfg_test_662432_12_s8765_1585_50_85_5_1\"\n",
"# cfg.model_name = \"ipcfg_test_663224_12_s8765_1585_50_85_5_1\"\n",
"\n",
"# cfg.model_name = \"ipcfg_test_662342_12_s8765_1585_50_85_5_1\"\n",
"# cfg.model_name = \"ipcfg_test_664223_12_s8765_1585_50_85_5_1\"\n",
"\n",
"\n",
"# cfg.model_name = \"ipcfg_test_664223_12_s8765_1585_50_85_5_1_r\"\n",
"\n",
"# cfg.model_name = \"ipcfg_test_662432_12_s8765_1585_50_85_5_1_r\"\n",
"\n",
"\n",
"# cfg.model_name = \"ipcfg_test_661431_12_s8765_1585_50_85_5_1_r\"\n",
"# cfg.model_name = \"ipcfg_test_664113_12_s8765_1585_50_85_5_1_r\"\n",
"cfg.model_name = \"ipcfg_test_664113_12_s8765_1585_50_85_5_1_r_e50\"\n",
"# cfg.model_name = \"ipcfg_test_661431_12_s8765_1585_50_85_5_1_r_e50\" # best\n",
"\n",
"\n",
"\n",
"# cfg.model_name = \"ipcfg_test_664123_12_s8765\"\n",
"\n",
"cfg.model_file = \"last.pth\"\n",
"\n",
"cfg.model_name = \"ipcfg_test_oneset\"\n",
"# cfg.model_file = \"00000640.pth\"\n",
"\n",
"cfg.model.decoder.NT=6\n",
"cfg.model.decoder.n_set=1"
]
},
{
"cell_type": "markdown",
"id": "55c98afa",
"metadata": {},
"source": [
"### Model B"
]
},
{
"cell_type": "code",
"execution_count": 239,
"id": "cb582574",
"metadata": {},
"outputs": [],
"source": [
"ipcfg_la = False\n",
"if ipcfg_la:\n",
" cfg.model.decoder.name = \"IPCFG2DLADecHead\"\n",
" cfg.model.decoder.nt_la = True\n",
" \n",
" cfg.model_name = \"ipcfg_toyla_664113_s1\"\n",
" \n",
" cfg.model_file = \"00001920.pth\""
]
},
{
"cell_type": "code",
"execution_count": 240,
"id": "e947a42b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading from /mnt/yann/model/ipcfg/ipcfg_test_oneset/last.pth\n",
"Old configs:\n",
"\n",
"alias_root: /mnt/yann/model/ipcfg\n",
"model_root: /mnt/yann/model/ipcfg\n",
"alias_name: ipcfg_test_oneset\n",
"model_name: test\n",
"alias_odir: outs\n",
"model_file: ''\n",
"model_time: null\n",
"blockprint: false\n",
"purge_root: false\n",
"monitor: DalleMonitor\n",
"worker: IPCFG\n",
"verbose: true\n",
"seed: 9876\n",
"eval: false\n",
"rank: 0\n",
"mode: ddp\n",
"autocast: false\n",
"check_unused_param: false\n",
"num_proc: 1\n",
"num_gpus: 1\n",
"port: 1789\n",
"dist_url: tcp://localhost:${port}\n",
"model:\n",
" vq:\n",
" name: PretrainedVQGAN\n",
" load: false\n",
" gumbel: false\n",
" model_root: /mnt/yann/model/vqgan_tuned/logs\n",
" model_name: clevr_vqgan_128e_1024\n",
" model_file: last.ckpt\n",
" model_time: 2022-09-10T05-13-14\n",
" vocab_size: 1024\n",
" embed_size: 256\n",
" max_vis_len: 256\n",
" embedder:\n",
" name: DummyEncHead\n",
" encoder:\n",
" name: DummyEncHead\n",
" z_dim: 0\n",
" decoder:\n",
" name: IPCFG2DDecHead\n",
" NT: 6\n",
" T: 3\n",
" h_dim: 512\n",
" w_dim: 512\n",
" z_dim: ${..encoder.z_dim}\n",
" s_dim: 256\n",
" n_set: 1\n",
" nt_la: false\n",
" drop_1d: false\n",
" rate_1d: 5.0\n",
" beta_1d: 0.1\n",
" drop_2d: false\n",
" rate_2d: 3.0\n",
" beta_2d: 0.1\n",
" mini_1d_ll: ${..loss.mini_1d_ll}\n",
" mini_1d_2d: ${..loss.mini_1d_2d}\n",
" grid_size:\n",
" - 6\n",
" - 6\n",
" loss:\n",
" name: PCFGLossHead\n",
" layers: []\n",
" scaling: false\n",
" mini_1d_ll: false\n",
" mini_1d_2d: false\n",
" kl_decay_ratio: 1.0\n",
" kl_cycle_steps: 250.0\n",
" kl_activation: cosine\n",
" kl_max_beta: 1.0\n",
" kl_cycle: false\n",
" kl_max: null\n",
" bh_beta: 0\n",
" sh_beta: 1\n",
" th_beta: 0\n",
"optimizer:\n",
" use_lars: false\n",
" name: ${.opt_name}\n",
" warmup: false\n",
" warmup_steps: 250.0\n",
" warmup_times: ${.warmup_steps}\n",
" warmup_epoch: 10\n",
" warmup_fn: linear\n",
" decay_step: 1000.0\n",
" decay_rate: 1.0\n",
" lr: 0.01\n",
" weight_decay: 1.0e-07\n",
" betas:\n",
" - 0.9\n",
" - 0.75\n",
" max_gnorm: 3\n",
" lr_weight: 0.2\n",
" lr_bias: 0.0048\n",
" batch_size: ${running.batch_size}\n",
" epochs: ${running.epochs}\n",
" steps: []\n",
" gamma: 0.5\n",
" batch_sch: true\n",
" opt_name: AdamW\n",
" sch_name: WarmupExpDecayLR\n",
" optimizer:\n",
" - ${optimizer.opt_name}\n",
" - lr: ${optimizer.lr}\n",
" betas: ${optimizer.betas}\n",
" scheduler:\n",
" - ${optimizer.sch_name}\n",
" - step_size: ${optimizer.decay_step}\n",
" gamma: ${optimizer.decay_rate}\n",
" warmup_step: ${optimizer.warmup_steps}\n",
" warmup_fn: ${optimizer.warmup_fn}\n",
"data:\n",
" name: CLEVR\n",
" version: 0.0.1\n",
" num_proc: ${num_proc}\n",
" data_root: /mnt/yann/data/dallemini/2022-09-10T05-13-14_clevr_vqgan_128e_1024/last/CAPTION_4.0_25_captions\n",
" more_root: /mnt/yann/data/CLEVR_v1.0_320x320_bbox\n",
" dump_root: ${.data_root}/${model_time}_{model_name}\n",
" data_name: train/{001..007}.parquet\n",
" eval_name: val/{001..002}.parquet\n",
" test_name: ''\n",
" data_seed: null\n",
" train_samples:\n",
" - 0\n",
" - 1024\n",
" eval_samples:\n",
" - 0\n",
" - 512\n",
" test_samples: null\n",
" batch_size: ${running.batch_size}\n",
" crop_size:\n",
" - 0\n",
" - 0\n",
" - 256\n",
" - 256\n",
" resize_size: 256\n",
" max_obj_num: 1\n",
" min_obj_num: 1\n",
" max_txt_len: 64\n",
" min_txt_num: 2\n",
" max_txt_num: 100\n",
" min_pixel_num: 2\n",
" min_train_vid: 0\n",
" max_train_vid: 9999\n",
" train_divider: 10000\n",
" min_eval_vid: 0\n",
" max_eval_vid: 999\n",
" eval_divider: 2000\n",
" txt_vocab_name: txt_vocab_is_null\n",
" vis_vocab_name: vis_vocab_is_null\n",
" vis_vocab_size: 3\n",
" txt_special_token:\n",
" bos: <TXT|BOS>\n",
" pad: <PAD>\n",
" unk: <UNK>\n",
" vis_special_token:\n",
" bos: <VIS|BOS>\n",
" unk: <UNK>\n",
" txt_only: false\n",
" vis_only: false\n",
" use_preencoded_pandas: true\n",
" require_txt: false\n",
" rot90_image: 0.0\n",
" vis_fake_hw:\n",
" - 6\n",
" - 6\n",
" obj_fake_hw:\n",
" - - 4\n",
" - 1\n",
" - - 1\n",
" - 3\n",
" obj_delt_hw: []\n",
"running:\n",
" peep_rate: 32\n",
" topk_best: 3\n",
" save_rate: 10000000000.0\n",
" skip_save: false\n",
" save_last: false\n",
" nsampling: 1\n",
" infer_mode: false\n",
" batch_size: 8\n",
" epochs: 30\n",
" save_epoch: true\n",
" optim_rate: 1\n",
" v_peep_time: 1\n",
" v_peep_topk: -1\n",
"\n"
]
}
],
"source": [
"model = build_main_model(cfg, echo)\n",
"model.build(**{\n",
" \"encoder_vocab\": encoder_vocab,\n",
" \"decoder_vocab\": decoder_vocab,\n",
"})\n",
"\n",
"device = next(model.decoder_head.parameters()).device"
]
},
{
"cell_type": "markdown",
"id": "6e6bbbcc",
"metadata": {},
"source": [
"### Plots"
]
},
{
"cell_type": "code",
"execution_count": 241,
"id": "e787ddee",
"metadata": {
"code_folding": [
0
],
"scrolled": true
},
"outputs": [],
"source": [
"def show_results(vidx=2):\n",
"# vidx = 2\n",
"\n",
" gold = batch_dict[\"bbox\"][vidx]\n",
" pred = argmax[vidx]\n",
"\n",
" pred = yxyx_to_xyxy(pred)\n",
" pred = sorted(pred, key=lambda x: -x[0])\n",
"\n",
" gold_ = gold # None\n",
" pred_ = pred\n",
" pred_ = sorted(pred_, key=lambda x: -(x[2] - x[0]) * (x[3] - x[1]))\n",
"\n",
" visualize_2d_parse(\n",
" [], gold=gold_, area=(0, 0, 16, 16), scale=1.0, figsize=(5, 5),\n",
" rotate=False, new_size=None, convas_size=convas_size\n",
" )\n",
" \n",
" overlap_dict = visualize_partial_parse_toy(gold, pred, area=(0, 0, 16, 16), convas_size=convas_size)\n",
" print()\n",
" for k, v in overlap_dict.items():\n",
" print(k, len(v))\n",
" v = \"\\n\".join(f\"{a}\\t{b:.3f}\" for a, b in v)\n",
" print(v)"
]
},
{
"cell_type": "code",
"execution_count": 247,
"id": "c16a80fb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Biased by row 0.171 (6.000) col 0.514 (18.000) other 0.314 (11.000) # sample 4\n"
]
}
],
"source": [
"batch_dict, argmax, pcfgs = infer_batch(model, device, kbatch=0)\n",
"bias_by_width(argmax)"
]
},
{
"cell_type": "code",
"execution_count": 250,
"id": "aa5f9b0f",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1, 0, 4, 1) 5\n",
"(3, 2, 4, 6) 5\n",
"\n",
"(1, 0, 4, 1) 5\n",
"[1, 0, 2, 2]\t0.400\n",
"[0, 0, 6, 2]\t0.400\n",
"[0, 0, 4, 2]\t0.545\n",
"[2, 0, 4, 2]\t0.571\n",
"[2, 0, 4, 1]\t0.800\n",
"(3, 2, 4, 6) 5\n",
"[0, 2, 5, 6]\t0.333\n",
"[0, 2, 4, 6]\t0.400\n",
"[3, 2, 4, 4]\t0.667\n",
"[3, 4, 4, 6]\t0.667\n",
"[3, 2, 4, 6]\t1.000\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAR4AAAE2CAYAAAC3Nv2NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAARKElEQVR4nO3cfYwtd13H8c9nHs6tpIWqNTal2PqHxig+xGukNWJQI1IioTESYysNFUhrL/UBlEAjSpSAEtSYUHsbYmiVxOIDNSAgkcQQCeWPXo2JRDFVS1qlVPoAfdI9M/PzjzNzdvbcu3tmd+9+z8zt+5Wc3HN2zs5+Zs7MZ36/2c11SkkAECnbdAAAzz4UD4BwFA+AcBQPgHAUD4BwFA+AcMVeC0+dOjW537XfcMMNm46Akbv99ts3HeFZ4fjx495t2Z7F0+FvfXAu4Xg+WvaufbO0tnhSSkopqWma0X9g/Q2+7bbbJpE5yzJl2WLG2zSNmqbZcKL1ppr5pptukjSNzLaVZdnymK7revTHcpe5e76XQSOepmm0tbU1+g+rKLY3p65rbW1tjf7DKstSZVnKtuq61nw+H3Vm2yqKQrPZTJJUVZWqqhp95rIsl6/n87mqqtpgovVsazabqSgKpZSW+3nMsizTbDZTnudr3zt4qjWF0UM3OrO9fD72zP0y7/bzmPX3aX80PGarx8EUjgtJO/bzlM6/IfitFoBwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHAUD4BwFA+AcBQPgHDFkDfZVlEUSikddZ5DybJMtpfP8zzfcKL1+hmzLFNRDPpINqrLbHsymbNs+xo7heNC2s5sW3meL4/tseqOhyEGHTF5nivLstEXT/+DOXHihCRNKrM0/rzS9DMXRTH68rG9zGxbZVmOfj/3M68z+FK1n5WOydQyTy2vNL3MUzyWp5h5L4OKJ6Wkuq5H37hZlunkyZOSpKZp1DTNJDJ3V9+maVTX9YYTrdeNgKVpZO6mAF3muq7VNM2GU+1tqpmHTgkHFU/TNNra2hr9SdydELYnk7koimXmuq61tbW16UhrlWWp2WwmSaqqSvP5fMOJ1pvNZsvbBVVVqaqqTUfak+1l5qZpNJ/PR1/w3XE8ZBo7eMTTPcYupSTbk8nczzeFvNJ25qnsY2na+7l7PvbM+8nIr9MBhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeDbM9qYjAOGKIW/KskxlWappmqPOcyh5ni9P5CzLVBSFUkobTrU728qybJk5z3OVZTn6zHmen/Z8Kpkl7Xg+Vt2x0T0vimL0F6ksy5aZ19lX8UxB/yQeuhM2qX8wTSnzjTfeuOkYB3by5EkVxaBDf6O6Y6MrnillXmfwloy9bc+EzNjN1Pbz1PKuM6h4mqZRXdejHk5Li5FZN92aSuY8z5dD/7qu1TTNJDJ3br31VtV1vcE063XTlhMnTkiSqqoa/W2DbnqYZZlSSstjY8z6mdcZXDzz+Xz0G14UxfKeSV3X2traGv1JXJblMnPTNKPPvHrlreta8/l89Jlns9nydZd5zGzr2LFjy+KpqkpVVW061p72c49n8A2FMR9YU9bfryml0e/n1YxTzTwFU8+8l/HfyQRwzqF4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AISjeACEo3gAhKN4AIQbXDy2jzLHs9bqfp3Cfp5Cxj7bOzJPJf8UMw9VDHlTlmUqikIppaPOcyh5ni8/oDzPVZbl5DJLGnVm28uc0nb+sWfOsu1rbD//WPUzd/t87OWzup/3Mrh4ZrPZoUJFm2LmPM8ncVL0FcWgQ2hUiqKYVG7bKsty0zHOqn3t/TFf1aQzD0fJfPatZp5aXonMR2E/I7JBxdM0jaqqGv2Gd1NC26rrWnVdjz5znufLq++UMneqqlJd1xtMs97q9HAqmYuiUJZlSimpqio1TbPpWHvqZ15nX8Uz9g0vimI5F26aRvP5fPQnsbR9n6Su60lk7l/ZusxjtnolrqpKVVVtKM0w3Q3xrnjquh595izLlo+17x260rGfDFO1ul+nsJ+nkLFvannP5FzYhj7+jgdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQLjBxWP7KHM8a/X3q+3R7+fVjFPNPHZTzCwNz1kMeVOWZSrLUimlQ4U6almWLTc8z3PNZrPRZ87zfJk5y7LRZ7atLNu+XuV5vsE0w9jekbMoitGfyP39bFtFUezY72O0n4vQ4OIZ+0avmmLmPM8ncSL3kfnodcVzLhm8NWO+Cvf1G5fMR2P1qja1zFPIK00/814GFU/TNJrP56Pf+DzPl8Pouq5VVdUkMpdlKUmqqkp1XY868+q0paoqVVW1wUTrrY4Yuv08Zl3mPM+VUppE5izLBk8JBxfPFE5iafueyZQyd2XZFfyUdAU/dv0r8RQyd/d4plY8Q29xTOsmCIBzAsUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyBcMeRNtmX7qLOcFV3OqWTuZyTz0VjNSOajsZ+Mg4onyzLNZjOllA4V7KhlWbbc8ClmzvNcx44dG3Vm28qy7YFyURSTOCFWM/dfj1E/c5ZlKstSeZ5vONXeVvfzXgaPeIpi0FtHI8uy0R9cq8gcI8/z0Z/Eq6aYeS+D2ySlNOorsXT6UK9pmg2mGaafeQr7WNo51SLz0ekX+hQyn/WpVl3Xqqpq9Cdynucqy1K2Vde15vP56D+soiiW05VuP485s+0dV
96qqlTX9egz90fsVVWpqqoNJlrP9nJ6lVLSfD5XXdebjrWnbko4ZAQ8qHhSSqM/IaTFh5VSkm01TTP6E0LaeVI0TTP6E0LSaaPKKWTunwx1XY/+JO4KviueKWROKSnP80HFM63JOYBzAsUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyAcxQMgHMUDIBzFAyBcMeRNtpVlmZqmOeo8h2JbtpfPx565y7j6OqW0wVR76+/j7vXUMmdZNvrMWZZN6liWTt/PexlUPFmWaTabjfqDknZ+WHmey/bkMo/9hJC0oyxvvvnmDSbZn24/F0WxYxvGqH9Rsq2yLJXn+YZT7W31QrqXtcXTPymmxPbkMo/9ZDiToVe4MelGPFMypcxDjgnvdXU9derUuC+9AEbr+PHjuzbQnsUDAEdhGmM3AOcUigdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igdAOIoHQDiKB0A4igeTYvuk7bdtOgcOh+I5JNvX2L7X9pO2v2j747Z/sF32dtsf2OX77rf9TPt93eO9tq+w/ZTt88/wPf9o+w22L7edet93v+23HPW2thl23aYIKaUbU0q/tamfj7Nj0P/HgzOz/UZJb5F0o6RPSNqS9DJJr5T06QGreEVK6ZNnWO+Dkn5K0h29r71Q0rdL+lNJF7RfvjClVNn+Pkmfsn0qpfS3B98iIAYjngOy/TxJvynpRErpQymlp1JK85TSR1JKv3rI1d8p6bqVr10n6WMppUdW35xSulfS5yR9z5CV236N7U/bfo/tx2z/p+2ressvsf1h24/avs/269uvv0zSLZJ+uh1p/dOan3O97X+x/YTt/7B9Q2/ZS2w/aPtNth9uR4vXD8h+h+13rKzjzb11XG375bb/rc1/S+97v9/2PbYfb9/7Xtuz3vKX2v687a/Y/kPbn7L9ut7yn2u35zHbn7B92ZD9jdNRPAd3paTzJN19BOv+E0k/ZPsFkmQ7k3SNFoV0GttXSHqhpPt6X3u8m/Lt4kWSPi/pIknvlvRH3v4fnO6S9KCkS7QYeb3T9o+klP5G0jslfTCldH5K6bvXbMfDkn5C0nMlXS/p921/b2/5xZKeJ+n5kl4r6VbbX7tmnasu1uJzeL6kX5f0Pkk/K+m4pBdLepvtb27fW0v65Xabr5T0o5JukiTbF0n6C0lvlfT1WuybH+h+iO1XalG6PynpGyT9vRajTxxESonHAR6SrpX00Jr3vF3SB3ZZdr+kJyU93nu8vrf8k5JuaZ//mKT/kVS2ry+XlNrveaZ9/h61/7/SgOyvkXRf7/Vz2nVcLOkFWpygF/SWv0vSHeu2acDP/StJv9g+f0mbvegtf1jSFWvWcYekd6ysI29fX9Bux4t67z8l6epd1vVLku5un18n6Z7eMkt6QNLr2tcfl/Ta3vJM0tOSLtv0sTjFByOeg3tE0kW2D3Of7OqU0oW9x/t6y+6U9Or2+asl3ZVSmq98/0WSzpf0Ji1OwnIfP/uh7klK6en26flajHIeTSk90XvvF7QYUeyL7atsf7ad8jwu6eVt5s4jKaWq9/rpNsN+PJJSqtvnz7T/fqm3/Jlunba/1fZf237I9le1GL11eS7RomgkSWnRLg/21nOZpD9oR5KPS3pUi3La934BU63DuEfS/0m6+ojW/yFJl9r+YS2G92ecZqWU6pTS70n6X7XThkP6b0lfZ/uC3te+SdJ/dT9yyEpsH5P0l1qMxL4xpXShpI9pcbJuym2S/lXSt6SUnqvF1KnL80VJl3ZvbKedl/a+9wFJN6xcKL4mpfSZoOznFIrngFJKX9HinsKt7Q3N59gu26v8u3tvzWyf13scG7j+p7S45/B+SV9IixvIe/ltSW+2fd5Btqf3cx+Q9BlJ72rzfpcW91+6X6F/SdLl7X2nvcwkHdNiili1N69fephsZ8EFkr4q6Unb3ybp53vLPirpO9vPspB0QoupZ+ekpLfa/g5p8csF268Kyn3OoXgOIaX0u5LeKOnXtDjBHpD0Bi3uZXR+Rovhfvf4996yj3jn3/Gs3qi+U4sh/h8PiPNRSY9J6n4D9aTtF+97o7YzX67F6OduSb+Rtn/t/+ftv4/Y/ofdVtBO1X5B0p+1ua6R9OED5jlbfqXN8YQWN6E/2C1IKX1Z0qu0uNH+iBZ/unCvFqNapZTulvQ7ku5qp2n/LOkq4UD4z96BM2hHdA9Kujal9HebznOuYcQDtGz/uO0L2+lwd//nsxuOdU6ieHAoK1PF/uOg0zzZ/twu67z2bGY/gyu1mAp/WdIrtPit4zN7fwsOgqkWgHCMeACEo3gAhKN4AISjeACEo3gAhKN4AIT7f9D3An8T9A2yAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 360x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAR4AAAE2CAYAAAC3Nv2NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAQZUlEQVR4nO3cb4xld13H8c/3d865u5FuWtsasoi0T4yJQSXZXSiRav2XUBNDNSRGEhoNISVtfYJRGzUCJiIhYmIiTZtKsEoiWEMbCCjoA41I2aWDmCiGSLWkZZXYUnRpu8w953x9cOfcOXs7c+dMd+Z7fmf6fiUbZnfu3P3cO/e877l3lpq7CwAipbEHAHjxITwAwhEeAOEID4BwhAdAOMIDIFy57pMbGxuT+1n7bbfdNvYEZO7ee+8de8KLwqlTp2y3z60NT+c1n/+smmuePbhFh+ni02MveEFuaP9Dx5uxVxx9Fwvp9OnvHHvGkXfyZKPz53f//J7hcXc11zyrz73iJuX+jw3NTHccf0AbGxs6d+6c2rbNfnNKSSklnf7CGZ171Vm1bTv2pD11myWpbdvJbL799tt14sKGzp59MvvNZqaUkswWJw1N02T/WO5vPnPm6rWXHXTGI0mbm5vZf7PKcvvmNE2jzc3N7L9ZVVWpqipJi83z+TzrzWamsiw1m80kSXVdq67r7Dd397Ekzedz1XU94qK9mZlms5nKspS7L+/nnKWUNJvNVBTFnpcdHJ4pnD24+3Jj93Hum/sxd/fs477TfTylzTv9Plf9+3lqx99e+KkWgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMKVgy9YlnL3w9xy2VJKMrPlx0VRjLxob/2NKSWV5eBvyWi6zWY2mc0pbT/HTuFxIW1vNjMVRbF8bOeqezwMMfgRM5vNsg9P/xtzxx13SNJkNp+4IF04cWf2eyU97wCY2uayLLOPj5ktN5uZqqrK/n7ub97L4PDs50pzMpXNFwvpxIWNsWcceReLaT6Wp7h5ncHhqes6++KmlHTPPffozJkzuvvuu9W27SQ2F0WhM6fP6Oy5s2qaZuxJeyqKYnlK3bZt9pu7lwApJZ05fUbN2UZt2449a63+ZklqmmlsHvqScHB4Njc3sz+IVw+IKWwuy/KSB9fm5ubIi/ZWVZVms5mkxRPSfD4fedHeZrPZ8n6u61p1XY+8aD0zW25u21bz+Tz7wHfvsQ55GTs4PO6e/UEsbb/f0O3NfXN/3xT2StO7j6Vp38/dx7lv3s9GfpwOIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QjPyMxs7AlAuHLoBauqUtu2h7nlshVFsTyQU0oqy1LuPvKq3ZmZUkrLzUVRqKqq7DcXRfG8j6eyWdIlH+eqe2x0H5dlmf2TVEppuXkv+wrPFPQP4qF3wpj6D6Ypbi7LcjIHcqeqKpXl4If+aLrNXXimtHnPy617ptrY2HB31w2PPKzmmmcPaht28tPvkV7yzbFXHH3/d7V05efHXnHknTzZ6Pz57921QoPC07atmqbJ+nRaWpzqdS+3prK5KIrlGUPTNGrbdnKbm6YZedF63cuWbnNd19m/bdC9PEwpyd2Xj42c9TebmU6dOrVreAadu7Vtq/l8nv0NL8tyeaObptHm5mb2B3FVVcvNbdtmv7n/UlZahGc+n2e/eTabqSiK5UE8n8/HnrWWmenYsWPL8NR1rbqux5611n7e4xn8hkLOD6wp69+v7p79/dzf2H08pc3d76dg6pvXyf+dTABHDuEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhBofHzA5zx4vW6v06hft5Chv7zOySzVPZP8XNQ5VDLpRSUlmWcvfD3nNZiqJYfoOKolBVVZPbLCnrzWa23Nl9bGbZb04pLT/u9udsp825x6e/eS+DwzObzS5rVLQpbi6KYhIHRV9ZDnoIZaUsy0ntNjNVVTX2jAO1r3s/52c1aefTUTYfvNXNU9srsfkw7OeMbFB42rZVXdfZ3/DuJaGZqWkaNU2T/eaiKJbPvlPcXNe1mqYZedF63UuV7mxyKpvLslRKSe6uuq7Vtu3Ys9bqb97LvsKT+w0vy3L5WrhtW83n8+wPYmn7fZ6maSaxuXtmc/fl5pz130PrDuK6rkdetV73hngXnqZpst+cUlr+2vOyQ68094Nhqlbv1yncz1PY2De1vTs5Crehj3/HAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEGh8fMDnPHi1b/fjWz7O/n/sbu4ylt7n6fuylulobvLIdcKKWkqqrk7pc16rCllJY3vCgKzWaz7DcXRbHcnFLKfrOZKaXt56uiKEZcM4yZLXeamcqyzP5A7t/P3eb+/Z6j/TwJDQ5P7jd61RQ3F0UxiQO5j82HrwvPUTL41uT8LNzXLy6bD8fqs9rUNk9hrzT9zesMCk/btprP59nf+KIolqfRTdOorutJbK6qSpJU17Wapsl6c/eypXsGrutadV2PvGq97oyhKAq5+/J+ztkUN6eUBr8kHByeKRzE0vZ7JlPa3MWyC/wUlGUpd18GPnf993mmsLl7j2dq4Rn6Fse03gQBcCQQHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhXDrmQmcnMDnvLgeh2TmVzf+PUNnd7c9+8upHNh2M/GweFJ6Wk2Wwmd7+sYYctpbS84VPcXBSFjh07lvVmM1NK2yfKZVlO4oDoNpuZyrK85Dbkq
L85paSqqlQUxcir1lt9bKwz+IynLAddNBsppewfXKvYHKMoiuwP4lVT3LzO4Jq4e9bPxNLzT/Xath1xzTD9zVO4jyU2R+kHfQqbD/ylVtM0qus6+wO5KApVVSUzU9M0ms/n2X+zyrJcvlzp7uecN5vZ8n6WpLqu1TRN9pu7+9ndVde16roee9ZaZrZ8eeXums/nappm7FlrdS8Jh5wBDwpP983K+cElLb5Z7i4zU9u22R8Q0qUvY9u2zf6AkHTJmcNUNvcPhqZpsj+Iu8B34ZnCZndXURSDwjOtF+cAjgTCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOHKIRcyM6WU1LbtYe+5LGYmM1t+nPvmbuPq7919xFXrrd7HU9ssSSml7DenlCb1WJaefz+vMyg8KSXNZrOsv1HSpd+soihkZpPbnPsBIemSWJZlecnvc7Qay6ls7jaamaqqUlEUI69ab/WJdJ09w9M/KKbEzCa3OfeDYSfd2cOUsPlwDTnrGfxSCwAOiuV+Wg/g6JnGuRuAI4XwAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gwWbbwQTN72szOmdmNZvbl3ucfM7OfHHMjdkZ4RmRmbzKzR8zsW2b2X2b2V2b2uq3PvdPMPrTL1z1mZs9tfV3364/M7AYze8bMrtjha/7JzO40s+vNzHtf95iZ3XXYt3Vrw6636QV6naSfkvRyd3+1u/+Du3/fAV4/Dsmg/3c6Dp6ZvV3SXZLeJulTkjYlvV7SGyR9ZsBV/Iy7/+0O1/uEpDdK+pPen71S0vdL+nNJJ7b++Cp3r83stKS/N7MNd/+bF36LRnGdpMfc/Zmxh2B/OOMZgZldKel3JN3h7h9192fcfe7uH3f3X73Mq79f0q0rf3arpE+6+1OrF3b3RyT9q6RXDblyM/tFM/uMmf3+1kuc/zSzm3uff5mZfczMvmFmXzGzt279+esl/Yakn9860/rnXa7/LjN71MwumNmXzOxnd7ncWyT9saTXbl3fu8zspq3w7nT51Lvup8zsL8zs6iG3GQeP8IzjtZKOS3rwEK77zyT9iJl9j7Q44CS9SYsgPY+Z3SDplZK+0vuzb3Yv+XbxGklflnStpPdK+oBt/0ebPizpCUkv0+LM691m9uPu/teS3i3pI+5+hbv/0C7X/aikGyVdKeldkj5kZidXL+TuH9DibPHhret7x5q9kvTLkm6R9KNb256W9P49vgaHhPCM4xpJT7p7fRnX8dBWILpfb5Ukd39c0t9JevPW5X5C0jFJn1j5+ifN7DlJD0u6W9JD3Sfc/Sp3X/dy76vufp+7N1oE7aSkl27F7ocl/bq7X3T3L2pxVrJ6BrYrd3/A3c+7e+vuH5H075JePfTr13ibpN909yfc/duS3inpjWbG2w0jIDzjeErStZf5oL9lKxDdr/t6n7tf2+F5s6QPu/t85euvlXSFpF+RdJOkah9/9393H7j7s1sfXqHFmcQ33P1C77JflfTdQ6/YzG41sy92QdXibOzafWzbzXWSHuxd779JaiS99ACuG/tEeMbxsKRva3Hqfxg+KunlZvZjkn5Ou7zMcvfG3f9A0kVJtx/A33te0tVmdqL3Z6+Q9LXur1z3xWZ2naT7JN0p6Rp3v0rSv0g6iP/27uOSbl6J9XF3/9qeX4kDR3hG4O7/K+m3Jb3fzG4xs+8ws8rMbjaz9/YumszseO/XsYHX/4ykv5T0QS1eFj2yx5e8R9KvmdnxF3J7en/v45I+K+n3tvb+oKS3SOp+hP51Sddvve+0k5doEaf/kSQz+yUtzngOwj2SfncrbjKz7zKzNxzQdWOfCM9I3P19kt4u6be0ONAe1+KZ/qHexX5B0nO9X4/2PvfxlX/Hs/pG9f1avLz40wFzPqHFm63dT6C+ZWY37vtGbW++XouznwclvaP3Y/8Htv73KTP7wuoXuvuXJL1PizPCr0v6AUn/+AJ3rPpDSR+T9GkzuyDpc1q8SY4R8B97BxCOMx4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOH+H4vna0rsUOaeAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 360x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAR4AAAE2CAYAAAC3Nv2NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAQCklEQVR4nO3db6hk913H8c/3d86ZXTSxoY2UW6vJIxWpfyDdZouNxn/QPJBGKYiFBqUUFrd5UlGDim0FaylWEKwEYqmpBVsrTUip9d8DxZpNJasVtFJsMCVpa7BplE1MMnPO+fpg7pl77uzeuWd27/2e37l5v+DSu3vnzn5m7pz3nJkNW3N3AUCkNPYAAC89hAdAOMIDIBzhARCO8AAIR3gAhCs3ffHixYv8XTuAq3LLLbfYQV/bGJ4O/60PgKHMDuzNyqHhcXe5u9q2zT5AZqaiKCRpMptTSkpp+Yq3bVu1bTvyosOx+fiZmVJKq4O4aZrsH8vd5u7zTQad8bRtq/l8nv0PqyzL1Q+raRrN5/Psf1hVVamqqtXmxWKR9WYzU1mWms1mkqS6rlXXdfabq6pSSknursViobqux561kZlpNpupLEu5++p+zllKSbPZbPXkv8ngl1pTOHvozs7MbPV57pv7Me/u55z179P+2XDO1h8HU3hcSNp3P0/p+BuCv9UCEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQLhyyIXMTGVZyt2Pe881SSnJzFafF0Ux8qLD9TemlFSWg34ko+o2m9lkNqe09xw7hceFtLfZzFQUxeqxnavu8TDEoEdMURRKKWUfHjNb/XCmvDl3/QOgLMtJHMjd5u5JNPfN/ceFmamqqkk9lg8z+KlqmyvNxdQ2T22vxOYoU9y8yaDwuLuapsm+uP2XV23bqm3byW1ummbkRYfrn5lNYXP3EqDb3DSN2rYdedVmU9089CXhoPC0bav5fJ79QdwdEGY2mc1lWa42N02j+Xw+9qRDVVWl2WwmSarrWovFYuRFh5vNZquX3nVdq67rsSdtZGarzW3barFYZB/47nE85GXs4DOe7iN37i4zm8zm/r4p7JX2Nk/lPpamfT93n+e+eZuN+b+TCeDEITwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4RmZmY09AQhXDrlQSklVValt2+Pec02KolgdyCkllWUpdx951cHMTCml1eaiKFRVVfabi6K47POpbJa07/NcdY+N7vOyLLN/kkoprTYfZqvwTEH/IB56J4yp/2Ca0uZz586NPeOq3XvvvSrLQQ/9UXWPjS48U9p8mEG35OzZF9U0O9c0CCdNqbPfeVGnVY89ZCsvqNSZM2fGnrGVnZ0dPfTQQ2PPOFKDwtM0O3r44aeyPp2Wlmdm3cuttm3VNE32m4uiWJ36N02jtm0nsfn8+VqnL9V630ceUdM0Y0/aqHvZcv78eV1/6aIuXLiQ/dsG3cvDlJLOnDmjuq4ntfkwg8/dFotF9je8LMvVeyZN02g+n2d/EFdVtdrctm32m9dPpZum0WKxyH7zbDZb/brbnDMz06lTp1YHcV3Xquu8zy63eY9n8BsKOT+wpqx/v7p79vfz+sapbp6CqW/eJP93MgGcOIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEGxweMzvOHS9Z6/frFO7nKWzsM7N9m6eyf4qbhyoHX7As5e7HueWaFUWx+gEVRaGqqia3WVLWm81stVPa25/75pT2nmP7+3N1pc25x2d98yaDwzObza560BhSSpPbXBTFJA6KvrIc/BDKRlmWk9tdVdXYE47UVvd+zs9q0pVPR9l89NY3T22vxObjsM0Z2eDwzOfz7G94SkllWcrM1DSNmqbJfnNRFKtn3ylt7tR1raZpRlxzuPWXh1PZXJbl6qXLYrFQ27Yjr9psffMmg8NT13X2N7wsy9Vr4bZttVgssj+Ipb3X703TTGJz/5mt25yz9Wfiuq5V1/VIa4bp3hDvDuKmabLfnFJafRx62aFXmvvBMFXr9+sU7ucpbOyb2t4rOQm3oY//jgdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCDQ6PmR3njpes/v1qZtnfz+sbp7o5d1PcLA3fWQ69wqqq5O5XPShCSml1w4ui0Gw2y35zURSrzSml7DebmVLae74qimLENcOY2b6dZVlmfyCv389lWe77dY62eRLaKjxTklLK/ge1riiKSRzIfWyOUZaDD9VJGHxrcn4W7usXl83Ho7/3zGvPjLhkS9dLt1+axn0sTe9xIR3DS635fJ79jS+KYnUa3TSN6rqexObubLKuazVNk/Xm9ZctFx65oLquR1x0ODNTWZY6f+m8pIur+zln3ebuvp7C5pTS4JeEg8MzhYNY2nvPpG3byWzuYtm2rRaLxdhzttIFPnf9Z+IpbO7e45laeIa+xTGtN0EAnAiEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMKVQy9oZse548h0O81sEpv7G9l8PNY3svl4bLNxcHhms5nc/apHRUgprW54Smlym4ui0KlTp7LebGZKae9EuSzLSRwQ65v7v87R+uaqqlQUxYiLDre+eZPB4SnLwRfNQkop+wfXOjbHKIoi+4N43RQ3bzK4Ju6e9TOxdPmpXtu2I64Zpr95CvextP+lFpuPTz/oU9h8LC+15vN59gdyURSqqkpmpqZptFgssv9hlWW5ernSNI3qus56s5nte+at61pN02S/u
X/GXte16roecdHhzGzfy6vFYqGmaUZetVlKSVVVDToDHhye3A8IafnDcneZmdq2zf6AkPYfFG3bZn9ASLrsrHIKm/sHQ9M02R/EXeC78Exhs7urKIpB4ZnWi3MAJwLhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwhEeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA+AcIQHQDjCAyAc4QEQjvAACEd4AIQjPADCER4A4QgPgHDl0AumlNS27XFuuWZmJjNbfZ775m7j+q/dfcRVm/Xv4+7XU9ucUsp+c0rpss05P5aly+/nTQaHZzabZf2Dkvb/sIqikJlNbnPuB4SkfbG8++67R1yyne5+Lsty323I0fqTUlVVKopixEWHW9+8yaHhMTMVxdd09uzONQ/DSVLq7HdL3/LsP409ZCsvJOnWW28de8ZWdnZ2VmdpUzDkrGfQGc/nPnda0jPXugcnygfHHoAJs9xP6wGcPNM4dwNwohAeAOEID4BwhAdAOMIDIBzhARCO8AAIR3gAhCM8AMIRHgDhCA8my5Y+bGbPmNk/mtltZvbF3tcfN7MfH3MjrozwjMjM3mJmj5rZs2b2NTP7jJm9Yfdr7zazjx7wfY+b2fO739d9/L6ZnTWz58zsuit8zz+b2TvM7GYz8973PW5m9xz3bd3dcOBtukpvkPQTkl7t7q9z97939+86wuvHMRn87/HgaJnZOyXdI+mcpL+UNJf0RklvkvTZAVfxk+7+N1e43iclvVnSH/V+7zWSvkfSn0i6fve3b3D32sxeK+nvzOyiu//11d+iUdwk6XF3f27sIdgOZzwjMLOXSfpNSefd/ZPu/py7L9z9U+7+S9d49fdLumvt9+6S9Ofu/vT6hd39UUn/JukHhly5mf2cmX3WzH5n9yXOf5rZHb2vv8rMHjKzb5jZl8zs7bu//0ZJvyrpZ3bPtP7lgOu/x8weM7NLZvYFM/upAy73Nkl/KOn1u9f3HjO7fTe8V7p86l3302b2p2b28iG3GUeP8Izj9ZJOS3rgGK77jyX9kJl9u7Q84CS9RcsgXcbMzkp6jaQv9X7vf7qXfAe4VdIXJd0o6f2SPmR7//rTxyQ9KelVWp55vdfMftTd/0LSeyV93N2vc/fvP+C6H5N0m6SXSXqPpI+a2WX/Cp27f0jLs8ULu9f3rg17JeluSXdK+uHdbc+If1RoNIRnHK+Q9HV3r6/hOh7cDUT38XZJcvcnJP2tpLfuXu7HJJ2S9Om17/+6mT0v6YKkP5D0YPcFd7/B3Te93Puyu9/n7o2WQduR9Mrd2P2gpF9x9xfc/fNanpWsn4EdyN0/4e5fdffW3T8u6T8kvW7o929wTtKvufuT7v6ipHdLerOZ8XbDCAjPOJ6WdOM1Pujv3A1E93Ff72v3ay88b5X0MXdfrH3/jZKuk/SLkm6XVG3xZ/9X94m7/9/up9dpeSbxDXe/1LvslyV929ArNrO7zOzzXVC1PBu7cYttB7lJ0gO96/13SY2kVx7BdWNLhGccFyS9qOWp/3H4pKRXm9mPSPppHfAyy90bd/9dSS9I+oUj+HO/KunlZnZ97/e+Q9JXuj9y0zeb2U2S7pP0DkmvcPcbJP2rpGH/1wWbPSHpjrVYn3b3rxz6nThyhGcE7v6/kn5D0gfN7E4z+yYzq8zsDjN7f++iycxO9z5ODbz+5yT9maQPa/my6NFDvuV9kn7ZzE5fze3p/blPSHpY0m/v7v0+SW+T1P0V+lOSbt593+lKvlnLOP23JJnZz2t5xnMU7pX0W7txk5l9q5m96YiuG1siPCNx9w9IeqekX9fyQHtCy2f6B3sX+1lJz/c+Hut97VNr/x3P+hvV92v58uIjA+Z8Wss3W7u/gXrWzG7b+kbtbb5Zy7OfByS9q/fX/p/Y/d+nzeyy/3sKd/+CpA9oeUb4lKTvlfQPV7lj3e9JekjSX5nZJUmPaPkmOUbAP/YOIBxnPADCER4A4QgPgHCEB0A4wgMgHOEBEI7wAAhHeACEIzwAwv0/QQEZOjirj98AAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 360x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"vidx = 2\n",
"show_results(vidx=vidx)"
]
},
{
"cell_type": "markdown",
"id": "ed338a97",
"metadata": {},
"source": [
"## pred"
]
},
{
"cell_type": "markdown",
"id": "15bc44f5",
"metadata": {},
"source": [
"### PCFGs"
]
},
{
"cell_type": "code",
"execution_count": 251,
"id": "52a4aa3a",
"metadata": {},
"outputs": [],
"source": [
"rule_prob, root_prob, term_prob = pcfgs"
]
},
{
"cell_type": "markdown",
"id": "1c938386",
"metadata": {},
"source": [
"#### Binary rules"
]
},
{
"cell_type": "code",
"execution_count": 252,
"id": "18957cc9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([4, 6, 2, 9, 9]) torch.Size([4, 6])\n"
]
}
],
"source": [
"lr = rule_prob[:, :, 0] # (B, NT, NT_T, NT_T)\n",
"ab = rule_prob[:, :, 1] # (B, NT, NT_T, NT_T)\n",
"\n",
"lr_lp = lr.logsumexp((-1, -2))\n",
"ab_lp = ab.logsumexp((-1, -2))\n",
"print(rule_prob.shape, lr_lp.shape)"
]
},
{
"cell_type": "code",
"execution_count": 253,
"id": "fabb1c9b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([0.3788, 0.3788, 0.3788, 0.3788], device='cuda:0') uniform: 4.394449154672439\n"
]
}
],
"source": [
"part_prob = rule_prob.flatten(-2).log_softmax(-1)\n",
"part_ent = -(part_prob.exp() * part_prob).sum(-1)\n",
"print(part_ent.mean((-1, -2)), \"uniform:\", np.log(part_prob.shape[-1]))"
]
},
{
"cell_type": "code",
"execution_count": 254,
"id": "4761b8b9",
"metadata": {},
"outputs": [],
"source": [
"if ipcfg_la:\n",
" print((rule_prob.flatten(-2) == part_prob).all())"
]
},
{
"cell_type": "code",
"execution_count": 255,
"id": "7c3b2b93",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[1.0000, 1.0000],\n",
" [1.0000, 1.0000],\n",
" [1.0000, 1.0000],\n",
" [1.0000, 1.0000],\n",
" [1.0000, 1.0000],\n",
" [1.0000, 1.0000]], device='cuda:0')\n",
"tensor([[1.0517e+00, 1.0517e+00],\n",
" [5.1468e-01, 5.1468e-01],\n",
" [2.3368e-06, 2.3368e-06],\n",
" [7.0624e-01, 7.0624e-01],\n",
" [2.9700e-06, 2.9700e-06],\n",
" [5.6108e-07, 5.6108e-07]], device='cuda:0')\n",
"tensor(0.3788, device='cuda:0') tensor(nan, device='cuda:0')\n"
]
}
],
"source": [
"lr_ab = torch.stack((lr_lp, ab_lp), -1)\n",
"print(lr_ab[0].exp())\n",
"print(part_ent[0])\n",
"\n",
"fulcrum = 0.95\n",
"above = part_ent[lr_ab.exp() >= fulcrum].mean()\n",
"below = part_ent[lr_ab.exp() < fulcrum].mean()\n",
"print(above, below)"
]
},
{
"cell_type": "code",
"execution_count": 256,
"id": "acb68a7d",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# rule_prob[0, 1, 0].flatten(-2).exp() #.sum()"
]
},
{
"cell_type": "code",
"execution_count": 257,
"id": "b37cb133",
"metadata": {},
"outputs": [],
"source": [
"# x_entmax = entmax_bisect(rule_prob[1, 1, 0].flatten(-2), 1.05)\n",
"# x_entmax[x_entmax.isnan()] = 0\n",
"# x_entmax"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "10732388",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 258,
"id": "03d3d9ab",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([0.0094, 0.0094, 0.0094, 0.0094], device='cuda:0')\n",
"tensor([0.6931, 0.6931, 0.6931, 0.6931], device='cuda:0')\n",
"tensor([2.9802e-08, 2.9802e-08, 2.9802e-08, 2.9802e-08], device='cuda:0')\n"
]
}
],
"source": [
"if not ipcfg_la:\n",
" ent = -(rule_prob.exp() * rule_prob).sum(2).mean((-1, -2, -3))\n",
" print(ent)\n",
"\n",
" rule_lp = rule_prob.log_softmax(2)\n",
" ent = -(rule_lp.exp() * rule_lp).sum(2).mean((-1, -2, -3))\n",
" print(ent)\n",
"\n",
" rule_ent = -(lr_ab.exp() * lr_ab).sum(-1)\n",
" # print(rule_ent)\n",
" rule_ent = rule_ent.mean(-1)\n",
" print(rule_ent)\n",
"else:\n",
" lr_ent = -(lr.exp() * lr).sum((-1, -2)).mean(-1) # (B,)\n",
" ab_ent = -(ab.exp() * ab).sum((-1, -2)).mean(-1) # (B,)\n",
" print(lr_ent, ab_ent)"
]
},
{
"cell_type": "code",
"execution_count": 259,
"id": "2f8a6460",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([1., 1., 1., 1.], device='cuda:0')\n",
"tensor([1., 1., 1., 1.], device='cuda:0')\n",
"tensor([-0., -0., -0., -0.], device='cuda:0')\n"
]
}
],
"source": [
"if not ipcfg_la:\n",
" print(lr_lp.exp().mean(-1))\n",
" print(ab_lp.exp().mean(-1))\n",
"\n",
" lr_prefer = (lr_lp - np.log(lr_lp.shape[-1])).logsumexp(-1)\n",
" ab_prefer = (ab_lp - np.log(ab_lp.shape[-1])).logsumexp(-1)\n",
" lr_ab_prefer = torch.stack([lr_prefer, ab_prefer], -1)\n",
" ent = -(lr_ab_prefer.exp() * lr_ab_prefer).sum(-1)\n",
" print(ent)"
]
},
{
"cell_type": "markdown",
"id": "d9c8e1d1",
"metadata": {},
"source": [
"#### Start rules"
]
},
{
"cell_type": "code",
"execution_count": 260,
"id": "31c9f6e1",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.08 0.07 0.26 0.07 0.26 0.26]\n",
" [0.08 0.07 0.26 0.07 0.26 0.26]\n",
" [0.08 0.07 0.26 0.07 0.26 0.26]\n",
" [0.08 0.07 0.26 0.07 0.26 0.26]]\n"
]
}
],
"source": [
"with np.printoptions(precision=2, suppress=True):\n",
" if not ipcfg_la:\n",
" print(np.exp(root_prob.cpu().numpy()))\n",
" else:\n",
" B = root_prob.shape[0]\n",
" root_p = root_prob.view((B, 2, -1))\n",
" print(np.exp(root_p.cpu().numpy()), root_p.shape)"
]
},
{
"cell_type": "markdown",
"id": "38d61ee7",
"metadata": {},
"source": [
"#### Pre-terminal rules"
]
},
{
"cell_type": "code",
"execution_count": 261,
"id": "73b147f9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0.66 0. 0. ] [1]\n",
"[0.66 0. 0. ] [1]\n",
"[0.66 0. 0. ] [1]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0. 1. 0.] [2]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0. 1. 0.] [2]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0. 1. 0.] [2]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0.34 0. 1. ] [0]\n",
"[0. 1. 0.] [2]\n"
]
}
],
"source": [
"v_seq = batch_dict[\"image\"]\n",
"\n",
"isample = 0\n",
"\n",
"x = v_seq[isample].unsqueeze(-1).cpu().numpy()\n",
"if not ipcfg_la:\n",
" x_pv = term_prob[isample].cpu().numpy()\n",
"else: # B, L, T\n",
" B, L, T = term_prob.shape\n",
" x_pv = term_prob[isample].view((L, 2, -1))\n",
" x_pv = x_pv.cpu().numpy()\n",
"\n",
"with np.printoptions(precision=2, suppress=True):\n",
" for pv, xx in zip(x_pv, x):\n",
" print(np.exp(pv), xx)"
]
},
{
"cell_type": "code",
"execution_count": 262,
"id": "6194323b",
"metadata": {},
"outputs": [],
"source": [
"# np.exp(x_pv)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1680721",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "3b24330b",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "55465616",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "a79a0cf7",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}