-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgpu_setup.py
More file actions
38 lines (31 loc) · 1.38 KB
/
gpu_setup.py
File metadata and controls
38 lines (31 loc) · 1.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""GPU setup."""
import tensorflow as tf
from tensorflow.config.experimental import VirtualDeviceConfiguration
def _parse_gpu_indices(gpus):
    """Normalize a GPU selection into a list of integer indices.

    Args:
        gpus: A comma-separated string of indices such as ``"0,2"`` (the
            original, CLI-friendly form), a single ``int``, or an iterable
            of ints / numeric strings.

    Returns:
        list[int]: The indices, in the order given. Empty tokens (e.g. from
        a trailing comma in ``"0,1,"``) are skipped.

    Raises:
        ValueError: If a non-empty token is not an integer.
    """
    if isinstance(gpus, int):
        return [gpus]
    if isinstance(gpus, str):
        gpus = gpus.split(',')
    # int() tolerates surrounding whitespace, so "0, 1" parses as [0, 1].
    return [int(idx) for idx in gpus if str(idx).strip()]


def _print_devices(header, devices):
    """Print ``header`` followed by one ``<index> device`` line per device."""
    print(header)
    for i, device in enumerate(devices):
        print(" <{}> {}".format(i, str(device)))


def create_distribute(
        vgpus=1, memory_limit=512, gpu_idx=0, do_cpu=False, gpus=None):
    """Create a ``tf.distribute.MirroredStrategy`` for this machine.

    Args:
        vgpus: Number of virtual (logical) GPUs to carve out of the visible
            GPU at ``gpu_idx``. Values ``<= 1`` leave GPUs unsplit.
        memory_limit: ``memory_limit`` given to each
            ``VirtualDeviceConfiguration`` when ``vgpus > 1`` (TensorFlow
            interprets it in MB).
        gpu_idx: Index, into the visible GPU list, of the GPU to split.
        do_cpu: If True, skip all GPU configuration and mirror over
            ``/CPU:0`` only.
        gpus: Optional subset of physical GPUs to make visible. Accepts a
            comma-separated string of indices (e.g. ``"0,2"``), an int, or
            an iterable of ints. ``None`` leaves visibility unchanged.

    Returns:
        tf.distribute.MirroredStrategy: Over ``/CPU:0`` when ``do_cpu`` is
        set; otherwise constructed with default devices (all visible
        logical GPUs).

    Raises:
        IndexError: If an index in ``gpus`` does not name a physical GPU,
            or ``gpu_idx`` is out of range when ``vgpus > 1``.
        ValueError: If ``gpus`` contains a non-integer token.
        RuntimeError: Raised by TensorFlow if device visibility, memory
            growth, or virtual devices are configured after the runtime
            has already been initialized.
    """
    if do_cpu:
        return tf.distribute.MirroredStrategy(devices=["/CPU:0"])

    if gpus is not None:
        physical_gpus = tf.config.list_physical_devices('GPU')
        selected = [physical_gpus[idx] for idx in _parse_gpu_indices(gpus)]
        # CPUs are re-listed so they stay visible alongside the chosen GPUs.
        tf.config.set_visible_devices(
            tf.config.list_physical_devices('CPU') + selected)

    # Use a separate name instead of rebinding the ``gpus`` argument, which
    # holds a selection spec rather than device objects.
    visible_gpus = tf.config.get_visible_devices('GPU')
    # With memory growth enabled, runtime init does not grab all GPU memory.
    for gpu in visible_gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    if vgpus > 1:
        # Fail with a clear message instead of a bare "list index out of
        # range" when there is no GPU to split (negative indices still work).
        if not -len(visible_gpus) <= gpu_idx < len(visible_gpus):
            raise IndexError(
                "gpu_idx={} is out of range for {} visible GPU(s)".format(
                    gpu_idx, len(visible_gpus)))
        tf.config.experimental.set_virtual_device_configuration(
            visible_gpus[gpu_idx],
            [VirtualDeviceConfiguration(memory_limit=memory_limit)
             for _ in range(vgpus)])
        _print_devices(
            "Created {} Virtual GPUs:".format(vgpus),
            tf.config.experimental.list_logical_devices('GPU'))
    else:
        _print_devices(
            "Using {} GPUs:".format(len(visible_gpus)), visible_gpus)

    return tf.distribute.MirroredStrategy()