Last active
January 28, 2026 11:06
-
-
Save naoh16/75957105cdcf78c4c29b2e31761bd379 to your computer and use it in GitHub Desktop.
MMDA実習2: クラス定義関連
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # src/dm2/restaurant/__init__.py |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # src/dm2/restaurant/dataset.py | |
| # see also: PB5_ds_fst-ml.ipynb | |
| class DatasetRestaurantRetrieval: | |
| # レストラン検索用のBIOラベルの一覧: これはクラス変数 | |
| RESTAURANT_BIO_LABELS = [ | |
| 'B-budget', 'I-budget', | |
| 'B-mood', 'I-mood', | |
| 'B-place', 'I-place', | |
| 'B-genre', 'I-genre', | |
| 'B-style', 'I-style', | |
| 'B-rate', 'I-rate', | |
| 'O' | |
| ] | |
| def __init__(self): | |
| # 実データは以下の2変数に格納する | |
| self._x = [] | |
| self._y = [] | |
| # クラス外からは以下の変数名でアクセスする | |
| # cf. C言語のポインタ | |
| self.train_x = None | |
| self.train_y = None | |
| self.test_x = None | |
| self.test_y = None | |
| def load(self, csv_filename, num_train=80): | |
| # レストランデータ | |
| with open(csv_filename, 'r', encoding='utf-8') as f: | |
| lines = f.readlines() | |
| # データの前処理 | |
| for line in lines: | |
| _elems = line.rstrip().split(',') | |
| # _ = _chunks[0] # ID number | |
| # _ = _chunks[1] # raw text : 通常の日本語テキスト | |
| d = _elems[2].split('/') # chunked text: 分割済みの単語系列 | |
| a = _elems[3].split('/') # BIO labels : 単語毎の正解ラベルの系列 | |
| self._x.append(d) | |
| self._y.append(a) | |
| # 学習用と評価用に分割 | |
| self.train_x = self._x[:num_train] | |
| self.train_y = self._y[:num_train] | |
| self.test_x = self._x[num_train:] | |
| self.test_y = self._y[num_train:] | |
| # Debug information | |
| print(f'DEBUG; Dataset: train={len(self.train_x)}, test={len(self.test_x)}') | |
| @classmethod | |
| def slot_parser(cls, words, bio_labels): | |
| """BIO label系列をPB4で示したスロットの配列に変換する | |
| 1つのスロットは (Key, Value) からなる tuple/list である. | |
| """ | |
| result = [] | |
| for w, l in zip(words, bio_labels): | |
| if l == 'O': | |
| continue | |
| elif l[:2] == 'B-' or l[:2] == 'B_': # 文字列は配列のようにも扱える (read-only) | |
| slot_name = l[2:] | |
| result.append([slot_name, w]) | |
| elif l[:2] == 'I-' or l[:2] == 'I_': | |
| # 直前の B ラベルとnameが一致しているなら単語を結合する | |
| if len(result) > 0 and result[-1][0] == l[2:]: | |
| result[-1][1] += w | |
| else: | |
| print(f'Warning: The label {l} is orphan label?') | |
| else: | |
| # ここにたどり着いている場合,ラベル文字列そのものが不正である. | |
| print(f'Warning: The label {l} is unknown.') | |
| return result |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # dm2/src/dm2/restaurant/dm.py | |
| # see also: PB5_ds_fst-ml.ipynb | |
| # 演習課題のため,budget 関連の処理を省いている | |
| class DMRestaurantRetrieval: | |
| def __init__(self): | |
| # <<状態の定義>> | |
| self.states = [ | |
| (0, 'こんにちは。京都レストラン案内です。どの地域のレストランをお探しですか。'), | |
| (1, 'どのような料理がお好みですか。'), | |
| # (2, 'ご予算はおいくらぐらいですか。'), | |
| (3, '{place}で、{genre}のレストランを検索します。'), # 最終状態 | |
| (4, '地域名を「京都駅近辺のようにおっしゃってください。'), | |
| (5, '和食・洋食・中華・ファストフードからお選びください。'), | |
| # (6, '予算を「3000円以下」のようにおっしゃってください。') | |
| ] | |
| self.start_state_num = 0 | |
| self.end_state_nums = [3] | |
| # <<遷移の定義>> | |
| # 遷移元状態番号、遷移先状態番号、遷移条件(スロット名) | |
| self.transitions = [ | |
| (0, 1, 'place'), | |
| (0, 4, None), | |
| # (1, 2, 'genre'), | |
| (1, 3, 'genre'), | |
| (1, 5, None), | |
| # (2, 3, 'budget'), | |
| # (2, 6, None), | |
| (4, 1, 'place'), | |
| (4, 4, None), | |
| # (5, 2, 'genre'), | |
| (5, 3, 'genre'), | |
| (5, 5, None), | |
| # (6, 3, 'budget'), | |
| # (6, 6, None) | |
| ] | |
| # <<対話処理中の変数群>> | |
| # PB4の説明で global として利用していた変数を, | |
| # クラスのフィールドとして定義しようとしている. | |
| ## 現在の内部状態 | |
| self.current_state_num = self.start_state_num | |
| ## 遷移条件の保持 | |
| self.context_user_utterance = [] | |
| def reset(self): | |
| """初期状態にリセットする | |
| """ | |
| self.current_state_num = self.start_state_num | |
| self.context_user_utterance = [] | |
| def enter(self, user_utterance): | |
| """入力であるユーザ発話に応じてシステム発話を出力し、内部状態を遷移させる | |
| ただし、ユーザ発話の情報は「意図、スロット名、スロット値」のlistとする | |
| """ | |
| # スロット抽出の結果が空の場合もあるので: | |
| if len(user_utterance) == 0: | |
| user_utterance = [['', '']] | |
| # 最初の0番目のindexは1発話に対して複数のスロットが抽出された場合に対応するための措置. | |
| # ここでは1発話につき1つのフレームしか含まれないという前提 | |
| input_frame_name = user_utterance[0][0] | |
| input_frame_value = user_utterance[0][1] | |
| system_utterance = "" | |
| # 現在の状態からの遷移に対して入力がマッチするか検索 | |
| for trans in [t for t in self.transitions if t[0] == self.current_state_num]: | |
| # 無条件に遷移 | |
| if trans[2] is None: | |
| self.current_state_num = trans[1] | |
| return True | |
| # 条件にマッチすれば遷移 | |
| if trans[2] == input_frame_name: | |
| self.context_user_utterance.append((input_frame_name, input_frame_value)) | |
| self.current_state_num = trans[1] | |
| return True | |
| return False | |
| def get_system_utterance(self): | |
| """指定された状態に対応するシステムの発話を取得 | |
| """ | |
| utt_text = "" | |
| for _state in [s for s in self.states if s[0] == self.current_state_num]: | |
| utt_text = _state[1] | |
| # プレースメントホルダーの置換 | |
| # 例: "{place}で検索します。"-->"京都駅で検索します、" | |
| for ctx in self.context_user_utterance: | |
| utt_text = utt_text.replace(f'{{{ctx[0]}}}', ctx[1]) | |
| # 補足:↑の行は,もう少しかみ砕くと, | |
| # ctx = {'place', '京都駅'} | |
| # の時に, | |
| # utt_text = utt_text.replace('{place}', '京都駅') | |
| # という関数呼び出しになるようにしている. | |
| # f''構文で直接「京都駅」に置換しているわけではなく, | |
| # f''構文でreplaceの第一引数を用意しているだけ. | |
| return utt_text | |
| def is_finished(self): | |
| """対話が終了状態に到達しているかどうかを判定 | |
| """ | |
| return self.current_state_num in self.end_state_nums |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # dm2/src/dm2/restaurant/slu.py | |
| # see also: PB5_ds_fst-ml.ipynb | |
| import os | |
| import pickle | |
| import sys | |
| import sklearn_crfsuite | |
| from gensim.models import KeyedVectors # <-- word2vecを利用する際に必要 | |
| class SLURestaurantRetrieval: | |
| def __init__(self, modelfile=None): | |
| """コンストラクタ | |
| """ | |
| self.classifier = None # これまでの資料の clf 変数に相当 | |
| self.model_filename = 'gen/slu-slot-restaurant-crf.model' if modelfile is None else modelfile | |
| def load(self, model_filename=None): | |
| """学習済みモデルの読み込み | |
| """ | |
| if model_filename is not None: | |
| self.model_filename = model_filename | |
| with open(self.model_filename, 'rb') as f: | |
| self.classifier = pickle.load(f) | |
| def fit(self, x, y): | |
| """機械学習モデルの学習(CRF + One-hot encoding) | |
| """ | |
| # 特徴量の返還 | |
| x_feat = self.transform_x(x) | |
| # モデルの準備と学習 | |
| self.classifier = sklearn_crfsuite.CRF() | |
| self.classifier.fit(x_feat, y) | |
| # モデルファイルの保存 | |
| pickle.dump(self.classifier, open(self.model_filename, 'wb')) | |
| def predict(self, x): | |
| """機械学習モデルを利用した推論(CRF + One-hot encoding) | |
| """ | |
| if isinstance(x[0], list): | |
| # 複数の入力が二重配列として入力された場合 | |
| x_feat = self.transform_x(x) | |
| return self.classifier.predict(x_feat) | |
| else: | |
| # 1入力が単純配列として入力された場合 | |
| x_feat = self.transform_x(x) | |
| return self.classifier.predict(x_feat)[0] | |
| def transform_x(self, x): | |
| """CRFSuite用の特徴量に変換(One-hot encoding) | |
| 本来はone-hot encodingを行うが,CRFSuiteが自動で行ってくれるため, | |
| ここでは実質何もしないでよい | |
| """ | |
| return x | |
| class SLURestaurantRetrievalW2V(SLURestaurantRetrieval): | |
| def __init__(self, modelfile=None): | |
| """コンストラクタ | |
| """ | |
| super().__init__() | |
| self.model_filename = 'gen/slu-slot-restaurant-crf-w2v.model' if modelfile is None else modelfile | |
| # word2vecモデルの読み込み. バイナリモデル (kv) があるならそちらを使う. | |
| # 初回は .txt.bz2 の処理で時間がかかるが,2回目以降は高速に読み込める. | |
| sys.stderr.write(f'DEBUG: Load Word2vec model: ') | |
| if os.path.exists('gen/jawiki.entity_vectors.100d.kv'): | |
| sys.stderr.write(f'gen/jawiki.entity_vectors.100d.kv\n') | |
| self.model_w2v = KeyedVectors.load('gen/jawiki.entity_vectors.100d.kv', mmap='r') | |
| else: | |
| sys.stderr.write(f'gen/jawiki.entity_vectors.100d.txt.bz2\n') | |
| self.model_w2v = KeyedVectors.load_word2vec_format( | |
| 'gen/jawiki.entity_vectors.100d.txt.bz2', | |
| binary=False | |
| ) | |
| self.model_w2v.save('gen/jawiki.entity_vectors.100d.kv') | |
| sys.stderr.write(f'DEBUG: Loaded Word2vec model with vector size = {self.model_w2v.vector_size}\n') | |
| def transform_x(self, x): | |
| """CRFSuite用の特徴量に変換(Word2vec encoding) | |
| """ | |
| x_w2v = [] | |
| if isinstance(x[0], list): | |
| for datum in x: | |
| feature_vec = self.make_words_vec_with_w2v_for_crfsuite(datum) | |
| x_w2v.append(feature_vec) | |
| else: | |
| feature_vec = self.make_words_vec_with_w2v_for_crfsuite(x) | |
| x_w2v.append(feature_vec) | |
| return x_w2v | |
| def make_words_vec_with_w2v_for_crfsuite(self, words): | |
| sentence_vec = [None] * len(words) | |
| num_valid_word = 0 | |
| for n, w in enumerate(words): | |
| if w in self.model_w2v: | |
| # sentence_vec[n] = {'w2v': model_w2v[w][0]} | |
| _features = dict() | |
| for k, v in enumerate(self.model_w2v[w]): | |
| _features[f'w2v_{k:03d}'] = v | |
| sentence_vec[n] = _features | |
| else: | |
| # sentence_vec[n] = {'w2v': np.zeros(model_w2v.vector_size)} | |
| _features = dict() | |
| for k in range(self.model_w2v.vector_size): | |
| _features[f'w2v_{k:03d}'] = 0 | |
| sentence_vec[n] = _features | |
| return sentence_vec |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment