> 文章列表 > python基于机器学习的姓名预测性别网页app开发

python基于机器学习的姓名预测性别网页app开发

python基于机器学习的姓名预测性别网页app开发

前言

做这个项目的起因是之前csdn给我推荐了一个问答:基于机器学习的姓名预测性别的手机app开发。我点进去发现已经有人回答了,链接点进去一看,好家伙,这不是查表算概率吗,和机器学习有半毛钱关系。而且我觉得用姓名预测性别挺扯淡的,去查了一下,发现某知名爱国企业和国外的都有提供姓名预测性别的api,看来是可以尝试做的。

吐槽完了,先上最后做出来的结果:
python基于机器学习的姓名预测性别网页app开发
准确性还是可以的,并且支持人名批量查询,整个网页响应速度也很快。

环境

我用的是python3.10,需要安装以下包:
numpy
pandas
pypinyin
tensorflow-cpu
plotly(网页)
dash(网页)

方法

1.如何拿到人名和性别的数据。这个我一开始去搜那篇查表文章背后所用的数据库从哪来的,在github上找到了它的人名出现频率图,但是没有找到原始数据库,据说是从什么泄露的kaifangjilu里拿出来的,我一想这tm不是违法吗,难道没有合法手段拿到这样的数据集资源了吗?我还是在github上找到了120万人名和性别的数据集:
点进去找到Chinese_Names_Corpus_Gender(120W).txt这个文件
2.如何对中文姓名进行特征提取,转化为机器能理解的语言。 这个想来想去还是决定把中文先转换为无注音的拼音,然后对每一个字母进行字母表数字的转换,事实证明确实效果不错。

具体如下图所示:
python基于机器学习的姓名预测性别网页app开发

代码

代码分为三块,数据准备代码,数据训练代码,网页app代码。

数据准备代码

首先把下载下来的txt转换为csv文件:
python基于机器学习的姓名预测性别网页app开发
把前面的东西去掉,然后文件后缀一改,就摇身一变成为了以逗号为分割的经典csv文件。
python基于机器学习的姓名预测性别网页app开发
接下来开始读取处理数据:

import pandas as pd
df = pd.read_csv("test.csv")
df

这里我是在notebook里运行的,结果如下:

python基于机器学习的姓名预测性别网页app开发

我们首先需要把性别转换为0,1表示男,女:

df['sex'].replace(['男', '女','未知'],[0, 1, 2], inplace=True)

然后批量转换姓名并保存到新的csv文件中:

from pypinyin import lazy_pinyin
import time
count = 0
a1 = time.time()
for x in df['dict']:list_pinyin = lazy_pinyin(x) #["a","zuo"]c = ''.join(list_pinyin) #["azuo"]num_pinyin = [max(0.0, ord(char)-96.0) for char in c]num_pinyin_pad = num_pinyin + [0.0] * max(0, 20 - len(num_pinyin))df['dict'][count] = num_pinyin_pad[:15] #为了使输入向量固定长度,取前15个字符。count+=1a2 = time.time()if count % 10000 == 0:print(a2-a1)
df.to_csv('after_2.csv')

这里时间挺久的,因为数据量大,大概需要个半小时,我让它每10000个数据打印一下运行时间,可以去掉。然后有个细节就是因为要输入模型,所以要固定向量长度,即短的名字给它补0,长的名字给它截掉,一律取前十五个字母。 保存完csv之后就可以退出了。

数据训练代码

先把数据读进来,因为发现二分类表现更佳,所以我们这里排除性别为2也就是未知的名字。

import pandas as pd
import numpy as np
df = pd.read_csv('after_2.csv')
df_binary = df[df['sex']!=2]

准备输入向量:

import json
test_list = df_binary['dict'].values.tolist()
for i in range(len(test_list)):test_list[i] = eval(test_list[i])
X = np.array(test_list,dtype = np.float32)
y = np.asarray(df_binary['sex'].values.tolist())

其中X的形状是(1050353, 15),y的形状是(1050353,)。

划分训练集和测试集:

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X,
y,test_size=0.2,random_state=0)

准备模型:

from tensorflow.keras import Sequential
from tensorflow.keras.layers import Embedding, Bidirectional, LSTM, Dense
from tensorflow.keras.optimizers import Adamdef lstm_model(num_alphabets=27, name_length=15, embedding_dim=256):model = Sequential([Embedding(num_alphabets, embedding_dim, input_length=15),Bidirectional(LSTM(units=128, recurrent_dropout=0.2, dropout=0.2)),Dense(1, activation="sigmoid")])model.compile(loss='binary_crossentropy',optimizer=Adam(learning_rate=0.001),metrics=['accuracy'])return model

只有一层LSTM,cpu也可以轻松训练(指训练一个epoch需要半小时)

训练:

import numpy as np
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping
# Step 1: Instantiate the model
model = lstm_model(num_alphabets=27, name_length=15, embedding_dim=256)# Step 2: Split Training and Test DataX_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.2,random_state=0)# Step 3: Train the model
callbacks = [EarlyStopping(monitor='val_accuracy',min_delta=1e-3,patience=5,mode='max',restore_best_weights=True,verbose=1),
]history = model.fit(x=X_train,y=y_train,batch_size=64,epochs=3,validation_data=(X_test, y_test),callbacks=callbacks)# Step 4: Save the model
model.save('boyorgirl.h5')# Step 5: Plot accuracies
plt.plot(history.history['accuracy'], label='train')
plt.plot(history.history['val_accuracy'], label='val')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()

网络这一块是别人造好的轮子,我直接拿来用了,因为训练速度比较慢所以我这里只训练了三轮,训练集和测试集的准确率都在提升,已经接近0.88,说明继续训练还有提升空间:

python基于机器学习的姓名预测性别网页app开发

网页开发代码

这块比较长,就不解释了,需要注意的是我结尾设置app是在0.0.0.0上的3000端口运行的,也就是说如果你把它部在服务器上,可以直接访问服务器的ip地址加端口号访问网页。如果是本地查看设置127.0.0.1即可。
另外需要在同目录底下准备一个faq.md的说明文件,是放在网页底部进行说明的,比如下图:
python基于机器学习的姓名预测性别网页app开发

import os
import pandas as pd
import numpy as np
import re
from tensorflow.keras.models import load_model
from pypinyin import lazy_pinyin
import plotly.express as px
import dash
from dash import dash_table
import dash_bootstrap_components as dbc
from dash import dcc
from dash import html
from dash.dependencies import Input, Output, Statepred_model = load_model('boyorgirl.h5')# Setup the Dash App
external_stylesheets = [dbc.themes.LITERA]
app = dash.Dash(__name__, external_stylesheets=external_stylesheets)# Server
server = app.server# FAQ section
with open('faq.md', 'r') as file:faq = file.read()# App Layout
app.layout = html.Table([html.Tr([html.H1(html.Center(html.B('男孩或者女孩?'))),html.Div(html.Center("根据名字预测性别"),style={'fontSize': 20}),html.Br(),html.Div(dbc.Input(id='names',value='李泽,李倩',placeholder='输入多个名字请用逗号或者空格分开',style={'width': '700px'})),html.Br(),html.Center(children=[dbc.Button('提交',id='submit-button',n_clicks=0,color='primary',type='submit'),dbc.Button('重置',id='reset-button',color='secondary',type='submit',style={"margin-left": "50px"})]),html.Br(),dcc.Loading(id='table-loading',type='default',children=html.Div(id='predictions',children=[],style={'width': '700px'})),dcc.Store(id='selected-names'),html.Br(),dcc.Loading(id='chart-loading',type='default',children=html.Div(id='bar-plot', children=[])),html.Br(),html.Div(html.Center(html.B('关于该项目')),style={'fontSize': 20}),dcc.Markdown(faq, style={'width': '700px'})])
],style={'marginLeft': 'auto','marginRight': 'auto'})# Callbacks
@app.callback([Output('submit-button', 'n_clicks'),Output('names', 'value')], Input('reset-button', 'n_clicks'),State('names', 'value'))
def update(n_clicks, value):if n_clicks is not None and n_clicks > 0:return -1, ''else:return 0, value@app.callback([Output('predictions', 'children'),Output('selected-names', 'data')], Input('submit-button', 'n_clicks'),State('names', 'value'))
def predict(n_clicks, value):if n_clicks >= 0:# Split on all non-alphabet characters# Restrict to first 10 names onlynames = re.findall(r"\\w+", value)# Convert to dataframepred_df = pd.DataFrame({'name': names})list_list = []# Preprocessfor x in names:list_pinyin = lazy_pinyin(x)c = ''.join(list_pinyin)num_pinyin = [max(0.0, ord(char)-96.0) for char in c]num_pinyin_pad = num_pinyin + [0.0] * max(0, 20 - len(num_pinyin))list_list.append(num_pinyin_pad[:15])# Predictionsresult = pred_model.predict(list_list).squeeze(axis=1)pred_df['男还是女'] = ['女' if logit > 0.5 else '男' for logit in result]pred_df['可能性'] = [logit if logit > 0.5 else 1.0 - logit for logit in result]# Format the outputpred_df['name'] = namespred_df.rename(columns={'name': '名字'}, inplace=True)pred_df['可能性'] = pred_df['可能性'].round(2)pred_df.drop_duplicates(inplace=True)return [dash_table.DataTable(id='pred-table',columns=[{'name': col,'id': col,} for col in pred_df.columns],data=pred_df.to_dict('records'),filter_action="native",filter_options={"case": "insensitive"},sort_action="native",  # give user capability to sort columnssort_mode="single",  # sort across 'multi' or 'single' columnspage_current=0,  # page number that user is onpage_size=10,  # number of rows visible per pagestyle_cell={'fontFamily': 'Open Sans','textAlign': 'center','padding': '10px','backgroundColor': 'rgb(255, 255, 204)','height': 'auto','font-size': '16px'},style_header={'backgroundColor': 'rgb(128, 128, 128)','color': 'white','textAlign': 'center'},export_format='csv')], nameselse:return [], ''@app.callback(Output('bar-plot', 'children'), [Input('submit-button', 'n_clicks'),Input('predictions', 'children'),Input('selected-names', 'data')
])
def bar_plot(n_clicks, data, selected_names):if n_clicks >= 0:# Bar Chartdata = pd.DataFrame(data[0]['props']['data'])fig = px.bar(data,x="可能性",y="名字",color='男还是女',orientation='h',color_discrete_map={'男': 'dodgerblue','女': 'lightcoral'})fig.update_layout(title={'text': '预测正确的可能性','x': 0.5},yaxis={'categoryorder': 'array','categoryarray': selected_names,'autorange': 'reversed',},xaxis={'range': [0, 1]},font={'size': 14},width=700)return [dcc.Graph(figure=fig)]else:return []if __name__ == '__main__':app.run_server(host='0.0.0.0', port='3000', proxy=None, debug=False)