From 91a0f99d64f68d8a47c7571f78ce20ed17f7e6d4 Mon Sep 17 00:00:00 2001 From: jwhite Date: Sun, 3 Nov 2024 10:11:27 -0600 Subject: [PATCH] refactor to listen fxn, prep for threading --- pyemu/utils/os_utils.py | 159 +++++++++++++++++++++++++++++----------- 1 file changed, 116 insertions(+), 43 deletions(-) diff --git a/pyemu/utils/os_utils.py b/pyemu/utils/os_utils.py index c4f542e6..b7dd9126 100644 --- a/pyemu/utils/os_utils.py +++ b/pyemu/utils/os_utils.py @@ -517,20 +517,44 @@ def __init__(self,timeout=0.1): # while total < num_bytes: # data += s.recv() + def nonblocking_recv(self,s,msg_len): + try: + msg = s.recv(msg_len) + except socket.timeout as e: + emess = e.args[0] + if emess == 'timed out': + return None + else: + raise Exception(e) + except socket.error as e: + # Something else happened, handle error, exit, etc. + raise Exception(e) + else: + if len(msg) == 0: + return None + else: + return msg + # got a message do something :) + def recv(self,s,dtype=None): recv_sec_message = None - while True: - data = s.recv(len(self.sec_message_buf)) - if len(data) > 0: - recv_sec_message = [int(d) for d in data] - break - time.sleep(self.timeout) + + #data = s.recv(len(self.sec_message_buf)) + data = self.nonblocking_recv(s,len(self.sec_message_buf)) + + #if len(data) == 0: + # return 0 + if data is None: + return 0 + + recv_sec_message = [int(d) for d in data] self._check_sec_message(recv_sec_message) - while True: - data = s.recv(self.header_size) - if len(data) > 0: - break - time.sleep(self.timeout) + #data = s.recv(self.header_size) + #if len(data) == 0: + # return -1 + data = self.nonblocking_recv(s,self.header_size) + if data is None: + raise Exception("didnt recv header after security message") self.buf_size = int.from_bytes(data[self.buf_idx[0]:self.buf_idx[1]], "little") self.mtype = int.from_bytes(data[self.type_idx[0]:self.type_idx[1]], "little") self.group = int.from_bytes(data[self.group_idx[0]:self.group_idx[1]], "little") @@ -539,9 +563,14 @@ def recv(self,s,dtype=None): data_len = self.buf_size - self.data_idx[0] - 1 self.data_pak = None if data_len > 0: - self.data_pak = self.deserialize_data(s.recv(data_len),dtype) - #self.data_pak = data[self.data_idx[0]:] - print() + raw_data = self.nonblocking_recv(s,data_len) + if raw_data is None: + raise Exception("didnt recv data pack after header of {0} bytes".format(data_len)) + if dtype is None and self.mtype == 10: + dtype = float + self.data_pak = self.deserialize_data(raw_data,dtype) + return len(data) + data_len + def deserialize_data(self,data,dtype=None): if dtype is not None and dtype in [int,float]: @@ -582,7 +611,6 @@ def send(self,s,mtype,group,runid,desc,data): buf += full_desc.encode() buf += sdata s.send(buf) - print() def _check_sec_message(self,recv_sec_message): @@ -593,7 +621,7 @@ def _check_sec_message(self,recv_sec_message): class PyPestWorker(object): - def __init__(self, pst, host, port, timeout=0.1): + def __init__(self, pst, host, port, timeout=0.25): self.host = host self.port = port self._pst_arg = pst @@ -605,7 +633,11 @@ def __init__(self, pst, host, port, timeout=0.1): self.par_names = None self.obs_names = None + self.par_values = None + self._process_pst() + self.connect() + def _process_pst(self): if isinstance(self._pst_arg,str): @@ -618,66 +650,107 @@ def _process_pst(self): def connect(self): + print("trying to connect to {0}:{1}...".format(self.host,self.port)) self.s = None + c = 0 while True: try: + time.sleep(self.timeout) + print(".", end='') + c += 1 + if c % 75 == 0: + print('') self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.s.connect((self.host, self.port)) print("-->connected to {0}:{1}".format(self.host,self.port)) break - time.sleep(self.timeout) + except ConnectionRefusedError: continue except Exception as e: continue def recv(self,dtype=None): - self.net_pack.recv(self.s,dtype=dtype) - print("recv'd message type:", NetPack.netpack_type[self.net_pack.mtype]) + n = self.net_pack.recv(self.s,dtype=dtype) + if n > 0: + print("recv'd message type:", NetPack.netpack_type[self.net_pack.mtype]) + return n def send(self,mtype,group,runid,desc="",data=0): self.net_pack.send(self.s,mtype,group,runid,desc,data) print("sent message type:", NetPack.netpack_type[mtype]) - def initialize(self): - self.connect() + def listen(self): + self.s.settimeout(self.timeout) + while True: + time.sleep(self.timeout) + + n = self.recv() + if n > 0: + # need to sync here + if self.net_pack.mtype == 10: + self.par_values = self.net_pack.data_pak.copy() - #request for cwd - self.recv() + # request cwd + elif self.net_pack.mtype == 4: + self.send(mtype=5, group=self.net_pack.group, + runid=self.net_pack.runid, + desc="sending cwd", data=os.getcwd()) - if self.net_pack.mtype != 4: - raise Exception("unexpected net pack type, should be {0}, not {1}".\ - format(NetPack.netpack_type[4], - NetPack.netpack_type[self.net_pack.mtype])) - self.send(mtype=5,group=self.net_pack.group, - runid=self.net_pack.runid,desc="sending cwd",data=os.getcwd()) + elif self.net_pack.mtype == 8: + self.par_names = self.net_pack.data_pak - # par names - self.recv() - # todo: check revc'd mtype - self.par_names = self.net_pack.data_pak.copy() + elif self.net_pack.mtype == 9: + self.obs_names = self.net_pack.data_pak - # obs nanes - self.recv() - # todo check recv'd mtype - self.obs_names = self.net_pack.data_pak.copy() + elif self.net_pack.mtype == 6: + self.send(7, self.net_pack.group, + self.net_pack.runid, + "fake linpack result", data=1) - # lin pack request - self.recv() + elif self.net_pack.mtype == 15: + self.send(15, self.net_pack.group, + self.net_pack.runid, + "ping back") - self.send(7,self.net_pack.group,self.net_pack.runid,"fake linpack result",data=1) - print() + # def initialize(self): + # self.connect() + # + # #request for cwd + # self.recv() + # + # if self.net_pack.mtype != 4: + # raise Exception("unexpected net pack type, should be {0}, not {1}".\ + # format(NetPack.netpack_type[4], + # NetPack.netpack_type[self.net_pack.mtype])) + # self.send(mtype=5,group=self.net_pack.group, + # runid=self.net_pack.runid,desc="sending cwd",data=os.getcwd()) + # + # # par names + # self.recv() + # # todo: check revc'd mtype + # self.par_names = self.net_pack.data_pak.copy() + # + # # obs nanes + # self.recv() + # # todo check recv'd mtype + # self.obs_names = self.net_pack.data_pak.copy() + # + # # lin pack request + # self.recv() + # + # self.send(7,self.net_pack.group,self.net_pack.runid,"fake linpack result",data=1) + # if __name__ == "__main__": host = "localhost" port = 4004 - ppw = PyPestWorker(None,host,port) - ppw.initialize() + #ppw.initialize()