-
Notifications
You must be signed in to change notification settings - Fork 0
/
tensorflow_create_tfconfig.py
43 lines (32 loc) · 1.03 KB
/
tensorflow_create_tfconfig.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import ast
import json
import os, sys

import numpy as np
def get_job_node_list_slurm_rwth():
    """Return the distinct host names allocated to the current job.

    Parses the ``R_WLM_ABAQUSHOSTLIST`` environment variable as a Python
    literal and collects the first element of every entry. The variable is
    presumably a list of ``[hostname, count]``-style entries provided by the
    RWTH cluster environment -- TODO confirm the exact format.

    Returns:
        Sorted list of unique host names.

    Raises:
        KeyError: if ``R_WLM_ABAQUSHOSTLIST`` is not set.
        ValueError, SyntaxError: if its value is not a plain Python literal.
    """
    # ast.literal_eval only accepts literals (lists, tuples, strings,
    # numbers, ...), so a malformed or hostile environment value cannot run
    # arbitrary code the way eval() could.
    host_entries = ast.literal_eval(os.environ['R_WLM_ABAQUSHOSTLIST'])
    # The same host may occur in several entries: de-duplicate, and sort so
    # the result does not depend on set iteration order.
    return sorted({entry[0] for entry in host_entries})
def build_tf_config(port_range_start=23456):
    """Build the TF_CONFIG description for this process and print it as JSON.

    The function creates one worker entry ``"<host>:<port>"`` per task on
    every host of the job. Hosts are taken in sorted order. On each host, the
    tasks get consecutive ports starting at *port_range_start*. This
    process's ``task.index`` is read from the ``RANK`` environment variable.

    Note: the JSON is only printed to stdout; ``os.environ`` is NOT modified.
    Presumably the caller captures the output, e.g.
    ``export TF_CONFIG=$(python tensorflow_create_tfconfig.py)`` -- TODO
    confirm against the job script.

    Args:
        port_range_start: first TCP port used on each host (default 23456,
            the value previously hard-coded).

    Returns:
        The TF_CONFIG dictionary that was printed.

    Raises:
        KeyError: if ``SLURM_NTASKS_PER_NODE`` or ``RANK`` (or the host-list
            variable read by ``get_job_node_list_slurm_rwth``) is not set.
        ValueError: if ``SLURM_NTASKS_PER_NODE`` or ``RANK`` is not an int.
    """
    tasks_per_node = int(os.environ['SLURM_NTASKS_PER_NODE'])

    # Sorting the hosts makes the worker order independent of how the host
    # list was produced, presumably so every process derives the same list.
    # NOTE(review): assumes RANK numbers tasks host-major in this same order.
    list_workers = [
        f"{host}:{port_range_start + i}"
        for host in sorted(get_job_node_list_slurm_rwth())
        for i in range(tasks_per_node)
    ]

    tf_config = {
        'cluster': {
            'worker': list_workers
        },
        'task': {'type': 'worker', 'index': int(os.environ['RANK'])}
    }
    print(json.dumps(tf_config))
    return tf_config
if __name__ == "__main__":
    # Script entry point: print this process's TF_CONFIG JSON to stdout.
    build_tf_config()