博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Tensorflow rnn-word2vec-电影评论
阅读量:5789 次
发布时间:2019-06-18

本文共 5796 字,大约阅读时间需要 19 分钟。

import os
import re

import numpy as np
import pandas as pd
from gensim.models import word2vec
# Load the labeled training reviews (tab-separated). The columns used later
# in this file are `review` (raw text) and `sentiment` (class label).
data_t = pd.read_csv('labeledTrainData.tsv', sep='\t')
data_t.shape
(25000, 3)
# Build (or reload from disk) 50-dimensional word2vec embeddings trained on
# the unlabeled reviews, then keep only the word-vector table.
if not os.path.exists('mymodel'):
    if not os.path.exists('imdb_text'):
        # Write one line per review, made of lowercase alphabetic tokens only.
        data_un = pd.read_csv('unlabeledTrainData.tsv', header=0,
                              delimiter="\t", quoting=3)
        pat = re.compile(r'[A-Za-z]+')
        # The guard above guarantees the file does not exist yet, so plain
        # write mode states the intent better than append mode.
        with open('imdb_text', 'w', encoding='utf-8') as f:
            for rev in data_un.review:
                tokens = [x.lower() for x in pat.findall(rev)]
                f.write(' '.join(tokens) + '\n')
        del data_un
    sentences = word2vec.Text8Corpus("imdb_text")  # load the corpus
    # NOTE: no `sg` argument is passed, so gensim's default (sg=0, CBOW) is
    # used -- not skip-gram. The context window is left at its default of 5.
    # NOTE(review): `size=` is the gensim < 4 spelling (renamed to
    # `vector_size` in gensim 4) -- confirm the installed version.
    model = word2vec.Word2Vec(sentences, size=50)
    model.save('mymodel')
else:
    model = word2vec.Word2Vec.load('mymodel')
word_vectors = model.wv
del model
# Notebook cell: display the word-vector object.
word_vectors
# Turn each review into the list of embedding vectors of its known words.
# The tokens must be produced the same way as when the embedding corpus was
# written (runs of ASCII letters, lower-cased); a bare `split()` leaves
# capitals and punctuation attached, and such words are then missing from
# the vocabulary and silently dropped.
_token_pat = re.compile(r'[A-Za-z]+')
data_t['vec'] = data_t.review.apply(
    lambda text: [word_vectors[w]
                  for w in (t.lower() for t in _token_pat.findall(text))
                  if w in word_vectors])
# The raw text and the embedding table are not used after this point;
# drop them and run a garbage-collection pass before building the model.
del data_t['review']
del word_vectors
import gc
gc.collect()
14
# Discard reviews for which no word had an embedding (empty `vec` list),
# then show how many examples remain per sentiment class.
data_t = data_t[data_t['vec'].apply(len) > 0]
data_t.sentiment.value_counts()
0    12499
1    12495
Name: sentiment, dtype: int64
# Longest review, measured in embedded words.
maxlength = max(len(x) for x in data_t.vec)
maxlength
1622
# Number of reviews longer than 300 embedded words (these get truncated).
sum(data_t.vec.apply(len) > 300)
3246
def pad(x, max_len=300, dim=50):
    """Truncate or zero-pad a sequence of word vectors to a fixed length.

    Args:
        x: Sequence of word vectors, each of length ``dim``.
        max_len: Number of rows in the result (default 300).
        dim: Length of every word vector (default 50).

    Returns:
        A ``numpy.ndarray`` of shape ``(max_len, dim)`` holding the first
        ``max_len`` vectors of ``x`` followed by all-zero rows. The result is
        always an ndarray, whether ``x`` was truncated, padded or empty.
    """
    padded = np.zeros((max_len, dim))
    kept = min(len(x), max_len)
    # Guard the empty case: assigning an empty sequence into a (0, dim)
    # slice cannot be broadcast and would raise.
    if kept:
        padded[:kept] = x[:kept]
    return padded
# Bring every review to the fixed length expected by the network input.
data_t['vec'] = data_t.vec.apply(pad)

import tensorflow as tf
/anaconda3/envs/py35/lib/python3.5/importlib/_bootstrap.py:222: RuntimeWarning: compiletime version 3.6 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.5
  return f(*args, **kwds)
/anaconda3/envs/py35/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
# Training and model hyper-parameters.
learning_rate = 0.002
batch_size = 100
n_input = 50     # features per time step (the word-vector size)
n_steps = 300    # time steps per example (the padded review length)
n_hidden = 300   # units in the recurrent cell
n_classes = 2    # output classes
# Graph inputs: a batch of padded reviews, their integer class labels, and
# the dropout keep-probability (fed as 0.5 while training, 1 when evaluating).
x = tf.placeholder(tf.float32, [None, n_steps, n_input])
y = tf.placeholder(tf.int64, [None])
keep_prob = tf.placeholder("float")
def length(shuru):
    """Return, per example, how many time steps contain any non-zero value.

    Args:
        shuru: Input batch tensor; it is reduced over axis 2 and then axis 1,
            so it must have at least three dimensions (it is called with the
            ``[None, n_steps, n_input]`` placeholder ``x``).

    Returns:
        A float tensor with one entry per example: the count of time steps
        whose largest absolute feature value is non-zero. All-zero (padding)
        steps contribute 0.
    """
    # sign(max|.|) is 1 for a step with any non-zero feature, else 0.
    used = tf.sign(tf.reduce_max(tf.abs(shuru), axis=2))
    return tf.reduce_sum(used, axis=1)
# GRU cell whose outputs pass through dropout controlled by `keep_prob`.
cell = tf.contrib.rnn.DropoutWrapper(
    tf.contrib.rnn.GRUCell(n_hidden),
    output_keep_prob=keep_prob,
)
# Unroll the cell over each sequence; `sequence_length` tells dynamic_rnn
# where each review's real (non-padding) steps end.
output, _ = tf.nn.dynamic_rnn(
    cell,
    x,
    dtype=tf.float32,
    sequence_length=length(x),
)
output.get_shape()
TensorShape([Dimension(None), Dimension(300), Dimension(300)])
# Pick, for every example, the RNN output at its last real time step.
# The outputs are flattened to (batch * n_steps, hidden) and row
# `i * n_steps + (length_i - 1)` is gathered for example i.
# The number of examples is read from the tensor at run time instead of
# using the Python constant `batch_size`, so the graph also works when a
# batch of a different size is fed.
batch_rows = tf.shape(output)[0]
index = tf.range(0, batch_rows) * n_steps + (tf.cast(length(x), tf.int32) - 1)
flat = tf.reshape(output, [-1, int(output.get_shape()[2])])
last = tf.gather(flat, index)
# Linear output layer on top of the last relevant RNN output, followed by
# a softmax to obtain class probabilities.
weight = tf.Variable(tf.truncated_normal((n_hidden, n_classes), stddev=0.001))
bias = tf.Variable(tf.constant(0.1, shape=[n_classes]))
com_out = tf.matmul(last, weight) + bias
prediction = tf.nn.softmax(com_out)
# Mean cross-entropy over the batch, computed from the raw logits
# (`com_out`) and the integer labels.
cross_entropy = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=com_out)
)
# Adam with per-gradient norm clipping.
optimizer = tf.train.AdamOptimizer(learning_rate)
grads = optimizer.compute_gradients(cross_entropy)
# Clip every available gradient to an L2 norm of at most 5; pairs whose
# gradient is None are passed through unchanged.
grads = [
    (tf.clip_by_norm(g, 5), v) if g is not None else (g, v)
    for g, v in grads
]
train_op = optimizer.apply_gradients(grads)
/anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/gradients_impl.py:97: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
WARNING:tensorflow:From /anaconda3/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/clip_ops.py:110: calling reduce_sum (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.
Instructions for updating:
keep_dims is deprecated, use keepdims instead
# Fraction of examples in the batch whose most probable class equals the label.
correct_pred = tf.equal(tf.argmax(prediction, 1), y)
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
def generatebatch(X, Y, n_examples, batch_size, drop_remainder=True):
    """Yield consecutive mini-batches of inputs and labels.

    Args:
        X: Sliceable sequence of inputs.
        Y: Sliceable sequence of labels, aligned with ``X``.
        n_examples: Number of leading examples of ``X``/``Y`` to cover.
        batch_size: Number of examples per batch.
        drop_remainder: When True (the default), only full batches are
            produced and up to ``batch_size - 1`` trailing examples are
            skipped. When False, the leftover examples are yielded as a
            final, smaller batch.

    Yields:
        ``(batch_xs, batch_ys)`` tuples of slices of ``X`` and ``Y``.
    """
    full_end = (n_examples // batch_size) * batch_size
    for start in range(0, full_end, batch_size):
        stop = start + batch_size
        yield X[start:stop], Y[start:stop]  # one full batch
    if not drop_remainder and full_end < n_examples:
        yield X[full_end:n_examples], Y[full_end:n_examples]
# Open a session, initialise all variables and prepare a checkpoint saver.
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
saver = tf.train.Saver()
# Train for 10 passes over the data, reshuffling the examples on each pass.
reviews = data_t.vec.values
labels = data_t.sentiment.values
n_examples = len(reviews)
for step in range(10):
    index = np.random.permutation(n_examples)
    for batch_x, batch_y in generatebatch(reviews[index], labels[index],
                                          n_examples, batch_size):
        # Each element is one padded review; join them into a single
        # (batch, steps, features) array.
        batch_x = np.concatenate(batch_x).reshape(batch_size, 300, 50)
        # astype returns a new array, so the result has to be kept.
        batch_x = batch_x.astype(np.float32)
        sess.run(train_op, feed_dict={x: batch_x, y: batch_y, keep_prob: 0.5})
    # Report metrics on the last mini-batch of the pass with dropout
    # disabled (keep_prob = 1), then write a checkpoint.
    acc, loss = sess.run([accuracy, cross_entropy],
                         feed_dict={x: batch_x, y: batch_y, keep_prob: 1})
    saver.save(sess, './lesson0', global_step=step)
    print("Iter " + str(step) + ", Minibatch Loss= " + "{}".format(loss) + ", Training Accuracy= " + "{}".format(acc))
print("Optimization Finished!")
Iter 0, Minibatch Loss= 0.3504045009613037, Training Accuracy= 0.8799999952316284
Iter 1, Minibatch Loss= 0.2799288034439087, Training Accuracy= 0.8899999856948853
Iter 2, Minibatch Loss= 0.25252586603164673, Training Accuracy= 0.8700000047683716
Iter 3, Minibatch Loss= 0.2636661231517792, Training Accuracy= 0.9300000071525574

转载地址:http://uzhyx.baihongyu.com/

你可能感兴趣的文章
老话重谈 加密身份验证
查看>>
关于bacula网络备份软件的安装以及配置2
查看>>
MySQL中的安全更新模式
查看>>
关于完全卸载Office的一些记录
查看>>
DC学院数据分析学习笔记(四):爬虫的一些高级技巧
查看>>
Android实现自动更新功能
查看>>
运维的shel小编(4)
查看>>
搭建网站必不可少的知识8
查看>>
RHEL5.1安装 VM TOOL
查看>>
Perl开发的几个小注意事项
查看>>
SQL Server数据库备份恢复常见问题(不断更新中)
查看>>
实现hive proxy1-hive认证实现
查看>>
LinuxShell脚本之利用rsync+ssh实现Linux文件系统远程备份
查看>>
设计和使用维护计划
查看>>
Hyper-V 2016 系列教程3 Hyper-V 组件的添加
查看>>
func install in ubuntu-server
查看>>
PostgreSQL数据库pg_dump命令行不输入密码的方法
查看>>
asp教程八:访问数据库
查看>>
Linux 文件系统权限记序
查看>>
Exchange2010高可靠性和可用性解决方案
查看>>