


import torch
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import numpy as np

# Hyperparameters.
K: int = 100  # negative samples drawn per context word
C: int = 3  # context words taken on EACH side of the center word
MAX_VOCAB_SIZE: int = 30_000  # vocabulary size, including the "<unk>" slot
BATCH_SIZE: int = 128  # examples per DataLoader batch

# Load the training corpus. Decode explicitly as UTF-8 so the result does not
# depend on the platform's default text encoding.
with open("data/text8.train.txt", "r", encoding="utf-8") as fin:
    text = fin.read()

# Tokenize on whitespace; str.split() already returns a fresh list.
text = text.split()
if not text:
    # An empty corpus would make the normalisation below divide by zero and
    # produce NaN negative-sampling weights.
    raise ValueError("data/text8.train.txt contains no tokens")

# {word: count} for the MAX_VOCAB_SIZE - 1 most frequent words, leaving one
# slot for "<unk>". A literal "<unk>" token in the corpus is set aside first,
# so it is counted as unknown and "<unk>" always receives the LAST index
# (VOCAB_SIZE - 1), which is the index out-of-vocabulary words are mapped to.
token_counts = Counter(text)
token_counts.pop("<unk>", None)
vocab = dict(token_counts.most_common(MAX_VOCAB_SIZE - 1))
# Every token not kept above (including literal "<unk>" tokens) counts as <unk>.
vocab["<unk>"] = len(text) - sum(vocab.values())

# index -> word
idx_to_word = list(vocab)
# word -> index
word_to_idx = {word: i for i, word in enumerate(idx_to_word)}

# Raw occurrence count for each vocabulary index.
word_counts = np.array(list(vocab.values()), dtype=np.float32)
# Unigram distribution raised to the 3/4 power and renormalised; used as the
# noise distribution for negative sampling.
word_freqs = word_counts / np.sum(word_counts)
word_freqs = word_freqs ** (3. / 4.)
word_freqs = word_freqs / np.sum(word_freqs)

VOCAB_SIZE = len(idx_to_word)

class WordEmbeddingDataset(torch.utils.data.Dataset):
    """Skip-gram training examples with negative sampling.

    Item ``idx`` is ``(center_word, pos_words, neg_words)``:

    * ``center_word``: index of the token at position ``idx`` (0-dim long tensor).
    * ``pos_words``: indices of the ``window_size`` tokens on each side of the
      center (``2 * window_size`` values). Positions wrap around the ends of
      the corpus.
    * ``neg_words``: ``num_negatives * 2 * window_size`` indices drawn with
      replacement from ``word_freqs``.

    Args:
        text: Corpus as a sequence of tokens.
        word_to_idx: Mapping token -> vocabulary index. Tokens missing from it
            are encoded with the index of ``"<unk>"``, or with the last index
            (``len(idx_to_word) - 1``) if ``"<unk>"`` is not a key.
        idx_to_word: List mapping vocabulary index -> token.
        word_freqs: Non-negative sampling weight for each vocabulary index.
        word_counts: Count for each vocabulary index; stored, not used here.
        window_size: Context words per side. ``None`` uses the module-level ``C``.
        num_negatives: Negatives per context word. ``None`` uses the
            module-level ``K``.
    """

    def __init__(self, text, word_to_idx, idx_to_word, word_freqs, word_counts,
                 window_size=None, num_negatives=None):
        super().__init__()
        self.window_size = C if window_size is None else window_size
        self.num_negatives = K if num_negatives is None else num_negatives
        # Derive the OOV index from the vocabulary that was passed in rather
        # than from a module-level size, so any vocabulary works.
        unk_idx = word_to_idx.get("<unk>", len(idx_to_word) - 1)
        # Build the long tensor directly; the legacy torch.Tensor(...).long()
        # path first materialised a float32 copy of the entire corpus.
        self.text_encoded = torch.tensor(
            [word_to_idx.get(t, unk_idx) for t in text], dtype=torch.long
        )
        # dict: word -> index
        self.word_to_idx = word_to_idx
        # list: index -> word
        self.idx_to_word = idx_to_word
        # Negative-sampling weights (copied into a float32 tensor).
        self.word_freqs = torch.tensor(word_freqs, dtype=torch.float)
        # Per-index counts (copied into a float32 tensor).
        self.word_counts = torch.tensor(word_counts, dtype=torch.float)
        # Relative context offsets -w..-1, 1..w (center excluded), built once
        # instead of on every __getitem__ call.
        self._offsets = torch.cat([
            torch.arange(-self.window_size, 0),
            torch.arange(1, self.window_size + 1),
        ])

    def __len__(self):
        return len(self.text_encoded)

    def __getitem__(self, idx):
        # Center word.
        center_word = self.text_encoded[idx]
        # Context positions wrap around the corpus ends. For integer tensors,
        # % takes the sign of the divisor, exactly like Python's int %.
        pos_indices = (idx + self._offsets) % len(self.text_encoded)
        pos_words = self.text_encoded[pos_indices]
        # Negatives are sampled with replacement from word_freqs.
        # NOTE: they are not filtered, so they may coincide with the center
        # word or with one of the context words.
        neg_words = torch.multinomial(
            self.word_freqs, self.num_negatives * pos_words.shape[0], True
        )
        return center_word, pos_words, neg_words

# Wrap the encoded corpus in the skip-gram dataset and batch it. shuffle=True
# visits center positions in a random order each epoch; num_workers=0 loads
# batches in the main process.
dataset = WordEmbeddingDataset(
    text=text,
    word_to_idx=word_to_idx,
    idx_to_word=idx_to_word,
    word_freqs=word_freqs,
    word_counts=word_counts,
)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)


