From 105f42935332d0fdfbb50a0063d1eefcfaf65d68 Mon Sep 17 00:00:00 2001 From: deanlee Date: Thu, 1 Aug 2024 19:30:22 +0800 Subject: [PATCH] refactor update update test case improve update() --- opendbc/can/common.h | 9 +-- opendbc/can/common.pxd | 17 ++--- opendbc/can/common_dbc.h | 8 --- opendbc/can/parser.cc | 74 +++++++++------------ opendbc/can/parser_pyx.pyx | 87 ++++++++++++------------- opendbc/can/tests/test_packer_parser.py | 7 +- 6 files changed, 90 insertions(+), 112 deletions(-) diff --git a/opendbc/can/common.h b/opendbc/can/common.h index 8b5b8d320d..b79bae6aad 100644 --- a/opendbc/can/common.h +++ b/opendbc/can/common.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -65,7 +66,6 @@ class CANParser { private: const int bus; const DBC *dbc = NULL; - std::unordered_map message_states; public: bool can_valid = false; @@ -75,15 +75,16 @@ class CANParser { uint64_t last_nonempty_nanos = 0; uint64_t bus_timeout_threshold = 0; uint64_t can_invalid_cnt = CAN_INVALID_CNT; + std::unordered_map message_states; CANParser(int abus, const std::string& dbc_name, const std::vector> &messages); CANParser(int abus, const std::string& dbc_name, bool ignore_checksum, bool ignore_counter); - void update(const std::vector &can_data, std::vector &vals); - void query_latest(std::vector &vals, uint64_t last_ts = 0); + std::set update(const std::vector &can_data); protected: - void UpdateCans(const CanData &can); + void clearAllValues(); + void updateCans(const CanData &can, std::set &updated_addresses); void UpdateValid(uint64_t nanos); }; diff --git a/opendbc/can/common.pxd b/opendbc/can/common.pxd index 21e276fa07..025f8a3c8d 100644 --- a/opendbc/can/common.pxd +++ b/opendbc/can/common.pxd @@ -4,6 +4,7 @@ from libc.stdint cimport uint8_t, uint32_t, uint64_t from libcpp cimport bool from libcpp.pair cimport pair +from libcpp.set cimport set from libcpp.string cimport string from libcpp.vector cimport vector from libcpp.unordered_map cimport unordered_map @@ -52,13 +53,6 @@ cdef extern from "common_dbc.h": unordered_map[uint32_t, const Msg*] addr_to_msg unordered_map[string, const Msg*] name_to_msg - cdef struct SignalValue: - uint32_t address - uint64_t ts_nanos - string name - double value - vector[double] all_values - cdef struct SignalPackValue: string name double value @@ -76,11 +70,18 @@ cdef extern from "common.h": uint64_t nanos vector[CanFrame] frames + cdef cppclass MessageState: + vector[Signal] parse_sigs + vector[double] vals + vector[vector[double]] all_vals + uint64_t last_seen_nanos + cdef cppclass CANParser: bool can_valid bool bus_timeout + unordered_map[uint32_t, MessageState] message_states CANParser(int, string, vector[pair[uint32_t, int]]) except + - void update(vector[CanData]&, vector[SignalValue]&) except + + set[uint32_t] update(vector[CanData]&) except + cdef cppclass CANPacker: CANPacker(string) diff --git a/opendbc/can/common_dbc.h b/opendbc/can/common_dbc.h index 19507ecd4e..2b99dad832 100644 --- a/opendbc/can/common_dbc.h +++ b/opendbc/can/common_dbc.h @@ -10,14 +10,6 @@ struct SignalPackValue { double value; }; -struct SignalValue { - uint32_t address; - uint64_t ts_nanos; - std::string name; - double value; // latest value - std::vector all_values; // all values from this cycle -}; - enum SignalType { DEFAULT, COUNTER, diff --git a/opendbc/can/parser.cc b/opendbc/can/parser.cc index a65a8ec2d6..d766c9534a 100644 --- a/opendbc/can/parser.cc +++ b/opendbc/can/parser.cc @@ -157,34 +157,44 @@ CANParser::CANParser(int abus, const std::string& dbc_name, bool ignore_checksum } } -void CANParser::update(const std::vector &can_data, std::vector &vals) { - uint64_t current_nanos = 0; +std::set CANParser::update(const std::vector &can_data) { + clearAllValues(); + std::set updated_addresses; + + if (can_data.empty()) { + return updated_addresses; + } + + if (first_nanos == 0) { + first_nanos = can_data.front().nanos; + } + for (const auto &c : can_data) { - if (first_nanos == 0) { - first_nanos = c.nanos; - } - if (current_nanos == 0) { - current_nanos = c.nanos; - } last_nanos = c.nanos; + updateCans(c, updated_addresses); + } + UpdateValid(last_nanos); + + return updated_addresses; +} - UpdateCans(c); - UpdateValid(last_nanos); +void CANParser::clearAllValues() { + for (auto &[_, state] : message_states) { + for (auto &vals : state.all_vals) { + vals.clear(); + } } - query_latest(vals, current_nanos); } -void CANParser::UpdateCans(const CanData &can) { +void CANParser::updateCans(const CanData &can, std::set &updated_addresses) { //DEBUG("got %zu messages\n", can.frames.size()); - bool bus_empty = true; - for (const auto &frame : can.frames) { if (frame.src != bus) { // DEBUG("skip %d: wrong bus\n", cmsg.getAddress()); continue; } - bus_empty = false; + last_nonempty_nanos = can.nanos; auto state_it = message_states.find(frame.address); if (state_it == message_states.end()) { @@ -202,14 +212,10 @@ void CANParser::UpdateCans(const CanData &can) { // continue; //} - state_it->second.parse(can.nanos, frame.dat); - } - - // update bus timeout - if (!bus_empty) { - last_nonempty_nanos = can.nanos; + if (state_it->second.parse(can.nanos, frame.dat)) { + updated_addresses.insert(frame.address); + } } - bus_timeout = (can.nanos - last_nonempty_nanos) > bus_timeout_threshold; } void CANParser::UpdateValid(uint64_t nanos) { @@ -239,27 +245,5 @@ void CANParser::UpdateValid(uint64_t nanos) { } can_invalid_cnt = _valid ? 0 : (can_invalid_cnt + 1); can_valid = (can_invalid_cnt < CAN_INVALID_CNT) && _counters_valid; -} - -void CANParser::query_latest(std::vector &vals, uint64_t last_ts) { - if (last_ts == 0) { - last_ts = last_nanos; - } - for (auto& kv : message_states) { - auto& state = kv.second; - if (last_ts != 0 && state.last_seen_nanos < last_ts) { - continue; - } - - for (int i = 0; i < state.parse_sigs.size(); i++) { - const Signal &sig = state.parse_sigs[i]; - SignalValue &v = vals.emplace_back(); - v.address = state.address; - v.ts_nanos = state.last_seen_nanos; - v.name = sig.name; - v.value = state.vals[i]; - v.all_values = state.all_vals[i]; - state.all_vals[i].clear(); - } - } + bus_timeout = (nanos - last_nonempty_nanos) > bus_timeout_threshold; } diff --git a/opendbc/can/parser_pyx.pyx b/opendbc/can/parser_pyx.pyx index 9fb0c9f021..9d0858b6a4 100644 --- a/opendbc/can/parser_pyx.pyx +++ b/opendbc/can/parser_pyx.pyx @@ -1,14 +1,14 @@ # distutils: language = c++ # cython: c_string_encoding=ascii, language_level=3 -from cython.operator cimport dereference as deref, preincrement as preinc from libcpp.pair cimport pair +from libcpp.set cimport set from libcpp.string cimport string from libcpp.vector cimport vector from libc.stdint cimport uint32_t from .common cimport CANParser as cpp_CANParser -from .common cimport dbc_lookup, SignalValue, DBC, CanData, CanFrame +from .common cimport dbc_lookup, DBC, CanData, CanFrame import numbers from collections import defaultdict @@ -18,15 +18,17 @@ cdef class CANParser: cdef: cpp_CANParser *can const DBC *dbc - vector[uint32_t] addresses + set[uint32_t] addresses cdef readonly: dict vl dict vl_all dict ts_nanos string dbc_name + int bus def __init__(self, dbc_name, messages, bus=0): + self.bus = bus self.dbc_name = dbc_name self.dbc = dbc_lookup(dbc_name) if not self.dbc: @@ -47,18 +49,11 @@ cdef class CANParser: address = m.address message_v.push_back((address, c[1])) - self.addresses.push_back(address) - - name = m.name.decode("utf8") - self.vl[address] = {} - self.vl[name] = self.vl[address] - self.vl_all[address] = defaultdict(list) - self.vl_all[name] = self.vl_all[address] - self.ts_nanos[address] = {} - self.ts_nanos[name] = self.ts_nanos[address] + self.addresses.insert(address) self.can = new cpp_CANParser(bus, dbc_name, message_v) - self.update_strings([]) + self._update_value_dicts(self.addresses) + self._map_dicts_by_name_to_address() def __dealloc__(self): if self.can: @@ -68,16 +63,7 @@ cdef class CANParser: # input format: # [nanos, [[address, data, src], ...]] # [[nanos, [[address, data, src], ...], ...]] - for address in self.addresses: - self.vl_all[address].clear() - cur_address = -1 - vl = {} - vl_all = {} - ts_nanos = {} - updated_addrs = set() - - cdef vector[SignalValue] new_vals cdef CanFrame* frame cdef CanData* can_data cdef vector[CanData] can_data_array @@ -91,7 +77,7 @@ cdef class CANParser: can_data = &(can_data_array.emplace_back()) can_data.nanos = s[0] can_data.frames.reserve(len(s[1])) - for f in s[1]: + for f in (f for f in s[1] if f[2] == self.bus): frame = &(can_data.frames.emplace_back()) frame.address = f[0] frame.dat = f[1] @@ -99,29 +85,38 @@ cdef class CANParser: except TypeError: raise RuntimeError("invalid parameter") - self.can.update(can_data_array, new_vals) - - cdef vector[SignalValue].iterator it = new_vals.begin() - cdef SignalValue* cv - while it != new_vals.end(): - cv = &deref(it) - - # Check if the address has changed - if cv.address != cur_address: - cur_address = cv.address - vl = self.vl[cur_address] - vl_all = self.vl_all[cur_address] - ts_nanos = self.ts_nanos[cur_address] - updated_addrs.add(cur_address) - - # Cast char * directly to unicode - cv_name = cv.name - vl[cv_name] = cv.value - vl_all[cv_name] = cv.all_values - ts_nanos[cv_name] = cv.ts_nanos - preinc(it) - - return updated_addrs + updated_addresses = self.can.update(can_data_array) + self._update_value_dicts(updated_addresses) + return updated_addresses + + cdef _update_value_dicts(self, set[uint32_t] &addrs): + # Iterate over the set of updated message addresses + for addr in addrs: + # Ensure the address exists in vl, vl_all, and ts_nanos, initializing if necessary + vl = self.vl.setdefault(addr, {}) + vl_all = self.vl_all.setdefault(addr, defaultdict(list)) + ts_nanos = self.ts_nanos.setdefault(addr, {}) + + # Iterate over the signals in the message state + state = &self.can.message_states.at(addr) + for i in range(state.parse_sigs.size()): + sig_name = state.parse_sigs[i].name + vl[sig_name] = state.vals[i] + vl_all[sig_name] = state.all_vals[i] + ts_nanos[sig_name] = state.last_seen_nanos + + # Clear vl_all for addresses not in the updated set + for addr in self.addresses: + if addrs.count(addr) == 0: + self.vl_all[addr].clear() + + cdef _map_dicts_by_name_to_address(self): + for address in self.addresses: + msg = self.dbc.addr_to_msg.at(address) + name = msg.name + self.vl[name] = self.vl[address] + self.vl_all[name] = self.vl_all[address] + self.ts_nanos[name] = self.ts_nanos[address] @property def can_valid(self): diff --git a/opendbc/can/tests/test_packer_parser.py b/opendbc/can/tests/test_packer_parser.py index acb5acebe9..f61db18961 100644 --- a/opendbc/can/tests/test_packer_parser.py +++ b/opendbc/can/tests/test_packer_parser.py @@ -126,6 +126,11 @@ def rx_steering_msg(values, bad_checksum=False): assert parser.vl["STEERING_CONTROL"]["STEER_TORQUE"] == 300 assert parser.vl_all["STEERING_CONTROL"]["STEER_TORQUE"] == [300] + def test_parser_empty_message(self): + parser = CANParser("toyota_nodsu_pt_generated", [("ACC_CONTROL", 0)]) + addr = parser.update_strings([]) + assert len(addr) == 0 + def test_packer_parser(self): msgs = [ ("Brake_Status", 0), @@ -264,7 +269,7 @@ def test_updated(self): can_msgs[frame].append(packer.make_can_msg("VSA_STATUS", 0, values)) idx += 1 - parser.update_strings([[0, m] for m in can_msgs]) + parser.update_strings([[random.randint(0, 255), m] for m in can_msgs]) vl_all = parser.vl_all["VSA_STATUS"]["USER_BRAKE"] assert vl_all == user_brake_vals