
Releasing TensorFlow resources promptly with the multiprocessing module

When using modules such as tf.data, TensorFlow can leak memory. Once a leak occurs, we want to save a checkpoint promptly, return the current state, and then restart TensorFlow to continue training incrementally.

If you run TensorFlow in a child process via subprocess.call(), you have to serialize and deserialize the arguments and results yourself, which is tedious. A rough sketch of what that entails is shown below.
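(This sketch is only for comparison; worker_script.py, its command-line interface, and the JSON round-trip are hypothetical stand-ins, not code from this post.)

import json
import subprocess
import tempfile

args = {'epoch': 0}
with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
    json.dump(args, f)  # serialize the arguments by hand
    args_path = f.name
result_path = args_path + '.out'

# worker_script.py (hypothetical) must parse argv, deserialize the args,
# run TensorFlow, and serialize its result back to result_path.
subprocess.call(['python', 'worker_script.py', args_path, result_path])

with open(result_path) as f:
    result = json.load(f)  # deserialize the result by hand
print('result: %s' % result)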
This post instead gives an implementation that runs TensorFlow in a child process via the multiprocessing module, which makes passing arguments and results trivial.
Without further ado, the code:
# coding=utf-8
'''
Created on Sep 18, 2018

@author: colinliang
'''
from __future__ import absolute_import, division, print_function


def run_tf(args, queue=None):
    print('\n\n------- beginning of tf process')
    print('args for tf: %s' % args)
    # Import TensorFlow inside the function so it is loaded in the child
    # process; all of its memory is reclaimed when that process exits.
    import tensorflow as tf
    import psutil
    sess = tf.Session()

    mem_start = psutil.virtual_memory().available
    batch = 2000000
    # Allocate a variable whose size grows with the epoch to simulate a leak.
    n = (args['epoch'] + 1) * batch
    with tf.device('/cpu:0'):
        v = tf.get_variable(name='tf_var', shape=[n], dtype=tf.float32,
                            initializer=tf.random_uniform_initializer(-1, 1, 0, dtype=tf.float32))
    sess.run(tf.global_variables_initializer())

    # Memory check: exit promptly once memory consumption has grown past
    # the threshold, i.e. a "leak" has been detected.
    if mem_start - psutil.virtual_memory().available > batch * 12:
        result = {'exit code': -1}
        if queue is not None:
            queue.put(result)
        return result

    import time
    time.sleep(10)

    r = sess.run(v[0])
    print('sess: %s' % sess)
    sess.close()
    result = {'first elem of tf var': r}
    if queue is not None:
        queue.put(result)
    print('------- end of tf process')
    return result

#####################################################
from multiprocessing import Process, Queue

# Adapted from https://stackoverflow.com/questions/39758094/clearing-tensorflow-gpu-memory-after-model-execution
# Process usage: https://docs.python.org/2/library/multiprocessing.html
if __name__ == '__main__':
    for i in range(5):
        q = Queue()
        args = {'epoch': i}
        p = Process(target=run_tf, args=(args, q))
        p.start()
        result = q.get()  # drain the queue before join() to avoid a potential deadlock
        p.join()
        print("result: %s" % result)
