00039-使用 pandas 为 TransE 预处理数据


前言

# data_preprocessing
#
# created by LuYF-Lemon-love  on October 31, 2022
#
# 该脚本为 TransE 生成数据集
#
# prerequisites:
#     ../origin_data/raw_data.csv
#
# 输出最终的数据
# output:
#     ../final_data/relation2id.txt
#     ../final_data/entity2id.txt
#     ../final_data/train2id.txt
#     ../final_data/valid2id.txt
#     ../final_data/test2id.txt

操作系统:Ubuntu 20.04.5 LTS

生成目录

$ mkdir -p ../final_data

导入第三方库

import numpy as np
import pandas as pd
import random

读取原始数据

df = pd.read_csv('../origin_data/raw_data.csv')

# 去掉 '病理', '诊断', '预防' 三列

df = df.loc[:, [column for column in df.columns if column not in ['病理', '诊断', '预防']]]

生成 relation2id.txt

relation2id = {}
f = open('../final_data/relation2id.txt', 'w')
f.write('%d\n' % (len(df.columns[1:])))
for id, relation in enumerate(df.columns[1:]):
    f.write("%s\n" % relation)
    relation2id[relation] = id
f.close()

生成 entity2id.txt

entitys = set()
triples = []

for index, Series in df.iterrows():
    head = Series['疾病名称'].replace(' ', '-')
    for column_name in df.columns[1:]:
        if Series[column_name] is not np.nan:
            for tail in Series[column_name].strip(' ;').split(';'):
                if tail != '':
                    tail = ''.join([ch for ch in tail if ch not in [' ', '\t', '\n', '\r']])
                    entitys.add(tail)
                    triples.append([head, tail, column_name])
    entitys.add(head)

entity2id = {}
f = open('../final_data/entity2id.txt', 'w')
f.write('%d\n' % (len(entitys)))
for id, entity in enumerate(list(entitys)):
    f.write('%s\n' % entity)
    entity2id[entity] = id
f.close()

shuffle 数据集

random.seed(42)
random.shuffle(triples)
total = len(triples)

生成 train2id.txt, valid2id.txt, test2id.txt

train_set = triples[:int(total * 0.8)]
valid_set = triples[int(total * 0.8):int(total * 0.9)]
test_set = triples[int(total * 0.9):]

f= open('../final_data/train2id.txt', 'w')
f.write('%d\n' % (len(train_set)))
for row in train_set:
    f.write('%d\t%d\t%d\n' % (entity2id[row[0]], entity2id[row[1]], relation2id[row[2]]))
f.close()

f= open('../final_data/valid2id.txt', 'w')
f.write('%d\n' % (len(valid_set)))
for row in valid_set:
    f.write('%d\t%d\t%d\n' % (entity2id[row[0]], entity2id[row[1]], relation2id[row[2]]))
f.close()

f= open('../final_data/test2id.txt', 'w')
f.write('%d\n' % (len(test_set)))
for row in test_set:
    f.write('%d\t%d\t%d\n' % (entity2id[row[0]], entity2id[row[1]], relation2id[row[2]]))
f.close()

final_data

该数据集是 医药知识图谱 的实体子集. raw_data.csv医药知识图谱 的原始数据. 该数据集是使用 data_preprocessing.ipynb 脚本移除了 raw_data.csv病理, 诊断, 预防三列数据得到的. 一共 32,831 个三元组, 12,728 个实体, 10 个关系, 被随机地分成了训练集 (26,264 个), 验证集 (3,283 个), 测试集 (3,284 个).

  • entity2id.txt: 第一行是实体个数. 其余行是实体名, 每行一个. (实体名内不能有空白符, 实体的 ID0 开始, 第二行的第一个实体的 ID 为 0, 第三行的第二个实体的 ID 为 1, …)

  • relation2id.txt: 第一行是关系个数. 其余行是关系名, 每行一个. (关系名内不能有空白符, 关系的 ID0 开始, 第二行的第一个关系的 ID 为 0, 第三行的第二个关系的 ID 为 1, …)

  • train2id.txt: 训练文件. 第一行是训练集三元组的个数. 其余行是 (e1, e2, rel) 格式的三元组, 每行一个. e1, e2 是实体 ID, rel 是关系 ID.

  • valid2id.txt: 验证文件. 第一行是验证集三元组的个数. 其余行是 (e1, e2, rel) 格式的三元组, 每行一个. e1, e2 是实体 ID, rel 是关系 ID.

  • test2id.txt: 测试文件. 第一行是测试集三元组的个数. 其余行是 (e1, e2, rel) 格式的三元组, 每行一个. e1, e2 是实体 ID, rel 是关系 ID.

结语

第三十九篇博文写完,开心!!!!

今天,也是充满希望的一天。


文章作者: LuYF-Lemon-love
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 LuYF-Lemon-love !
  目录