Skip to content

Commit

Permalink
refactor to listen fxn, prep for threading
Browse files Browse the repository at this point in the history
  • Loading branch information
jtwhite79 committed Nov 3, 2024
1 parent 126bcfd commit 91a0f99
Showing 1 changed file with 116 additions and 43 deletions.
159 changes: 116 additions & 43 deletions pyemu/utils/os_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,20 +517,44 @@ def __init__(self,timeout=0.1):
# while total < num_bytes:
# data += s.recv()

def nonblocking_recv(self,s,msg_len):
try:
msg = s.recv(msg_len)
except socket.timeout as e:
emess = e.args[0]
if emess == 'timed out':
return None
else:
raise Exception(e)
except socket.error as e:
# Something else happened, handle error, exit, etc.
raise Exception(e)
else:
if len(msg) == 0:
return None
else:
return msg
# got a message do something :)

def recv(self,s,dtype=None):
recv_sec_message = None
while True:
data = s.recv(len(self.sec_message_buf))
if len(data) > 0:
recv_sec_message = [int(d) for d in data]
break
time.sleep(self.timeout)

#data = s.recv(len(self.sec_message_buf))
data = self.nonblocking_recv(s,len(self.sec_message_buf))

#if len(data) == 0:
# return 0
if data is None:
return 0

recv_sec_message = [int(d) for d in data]
self._check_sec_message(recv_sec_message)
while True:
data = s.recv(self.header_size)
if len(data) > 0:
break
time.sleep(self.timeout)
#data = s.recv(self.header_size)
#if len(data) == 0:
# return -1
data = self.nonblocking_recv(s,self.header_size)
if data is None:
raise Exception("didnt recv header after security message")
self.buf_size = int.from_bytes(data[self.buf_idx[0]:self.buf_idx[1]], "little")
self.mtype = int.from_bytes(data[self.type_idx[0]:self.type_idx[1]], "little")
self.group = int.from_bytes(data[self.group_idx[0]:self.group_idx[1]], "little")
Expand All @@ -539,9 +563,14 @@ def recv(self,s,dtype=None):
data_len = self.buf_size - self.data_idx[0] - 1
self.data_pak = None
if data_len > 0:
self.data_pak = self.deserialize_data(s.recv(data_len),dtype)
#self.data_pak = data[self.data_idx[0]:]
print()
raw_data = self.nonblocking_recv(s,data_len)
if raw_data is None:
raise Exception("didnt recv data pack after header of {0} bytes".format(data_len))
if dtype is None and self.mtype == 10:
dtype = float
self.data_pak = self.deserialize_data(raw_data,dtype)
return len(data) + data_len


def deserialize_data(self,data,dtype=None):
if dtype is not None and dtype in [int,float]:
Expand Down Expand Up @@ -582,7 +611,6 @@ def send(self,s,mtype,group,runid,desc,data):
buf += full_desc.encode()
buf += sdata
s.send(buf)
print()


def _check_sec_message(self,recv_sec_message):
Expand All @@ -593,7 +621,7 @@ def _check_sec_message(self,recv_sec_message):
class PyPestWorker(object):


def __init__(self, pst, host, port, timeout=0.1):
def __init__(self, pst, host, port, timeout=0.25):
self.host = host
self.port = port
self._pst_arg = pst
Expand All @@ -605,7 +633,11 @@ def __init__(self, pst, host, port, timeout=0.1):
self.par_names = None
self.obs_names = None

self.par_values = None

self._process_pst()
self.connect()


def _process_pst(self):
if isinstance(self._pst_arg,str):
Expand All @@ -618,66 +650,107 @@ def _process_pst(self):


def connect(self):
print("trying to connect to {0}:{1}...".format(self.host,self.port))
self.s = None
c = 0
while True:
try:
time.sleep(self.timeout)
print(".", end='')
c += 1
if c % 75 == 0:
print('')
self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.s.connect((self.host, self.port))
print("-->connected to {0}:{1}".format(self.host,self.port))
break
time.sleep(self.timeout)

except ConnectionRefusedError:
continue
except Exception as e:
continue

def recv(self,dtype=None):
self.net_pack.recv(self.s,dtype=dtype)
print("recv'd message type:", NetPack.netpack_type[self.net_pack.mtype])
n = self.net_pack.recv(self.s,dtype=dtype)
if n > 0:
print("recv'd message type:", NetPack.netpack_type[self.net_pack.mtype])
return n


def send(self,mtype,group,runid,desc="",data=0):
self.net_pack.send(self.s,mtype,group,runid,desc,data)
print("sent message type:", NetPack.netpack_type[mtype])

def initialize(self):
self.connect()
def listen(self):
self.s.settimeout(self.timeout)
while True:
time.sleep(self.timeout)

n = self.recv()
if n > 0:
# need to sync here
if self.net_pack.mtype == 10:
self.par_values = self.net_pack.data_pak.copy()

#request for cwd
self.recv()
# request cwd
elif self.net_pack.mtype == 4:
self.send(mtype=5, group=self.net_pack.group,
runid=self.net_pack.runid,
desc="sending cwd", data=os.getcwd())

if self.net_pack.mtype != 4:
raise Exception("unexpected net pack type, should be {0}, not {1}".\
format(NetPack.netpack_type[4],
NetPack.netpack_type[self.net_pack.mtype]))
self.send(mtype=5,group=self.net_pack.group,
runid=self.net_pack.runid,desc="sending cwd",data=os.getcwd())
elif self.net_pack.mtype == 8:
self.par_names = self.net_pack.data_pak

# par names
self.recv()
# todo: check revc'd mtype
self.par_names = self.net_pack.data_pak.copy()
elif self.net_pack.mtype == 9:
self.obs_names = self.net_pack.data_pak

# obs nanes
self.recv()
# todo check recv'd mtype
self.obs_names = self.net_pack.data_pak.copy()
elif self.net_pack.mtype == 6:
self.send(7, self.net_pack.group,
self.net_pack.runid,
"fake linpack result", data=1)

# lin pack request
self.recv()
elif self.net_pack.mtype == 15:
self.send(15, self.net_pack.group,
self.net_pack.runid,
"ping back")


self.send(7,self.net_pack.group,self.net_pack.runid,"fake linpack result",data=1)

print()
# def initialize(self):
# self.connect()
#
# #request for cwd
# self.recv()
#
# if self.net_pack.mtype != 4:
# raise Exception("unexpected net pack type, should be {0}, not {1}".\
# format(NetPack.netpack_type[4],
# NetPack.netpack_type[self.net_pack.mtype]))
# self.send(mtype=5,group=self.net_pack.group,
# runid=self.net_pack.runid,desc="sending cwd",data=os.getcwd())
#
# # par names
# self.recv()
# # todo: check revc'd mtype
# self.par_names = self.net_pack.data_pak.copy()
#
# # obs nanes
# self.recv()
# # todo check recv'd mtype
# self.obs_names = self.net_pack.data_pak.copy()
#
# # lin pack request
# self.recv()
#
# self.send(7,self.net_pack.group,self.net_pack.runid,"fake linpack result",data=1)
#



if __name__ == "__main__":
host = "localhost"
port = 4004

ppw = PyPestWorker(None,host,port)
ppw.initialize()
#ppw.initialize()


0 comments on commit 91a0f99

Please sign in to comment.