Linear Discriminant Analysis: why the NumPy-based implementation has lower accuracy #14

Open
WhizZest opened this issue Sep 2, 2024 · 0 comments

WhizZest commented Sep 2, 2024

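The book states that the NumPy-based implementation reaches an accuracy of only 0.85, whereas sklearn reaches 1.0.
I checked the source code and found the cause:
In Listing 5-1, the calc_cov function, which computes the covariance matrix, should not standardize X and Y; it only needs to center them.
Concretely, the change is as follows:
Before: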

# Standardize the data
X = (X - np.mean(X, axis=0))/np.std(X, axis=0)
Y = X if Y == None else (Y - np.mean(Y, axis=0))/np.std(Y, axis=0)

After:

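# Center the data (subtract the column means only)
X = X - np.mean(X, axis=0)
# Use `is None`: `Y == None` compares element-wise when Y is an array
Y = X if Y is None else Y - np.mean(Y, axis=0)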
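
A minimal, self-contained sketch of the corrected function, for checking the difference in isolation. It assumes that calc_cov in Listing 5-1 returns `X.T @ Y / m` (division by the sample count m); that normalization and the helper `calc_cov_standardized`, which mimics the original version, are assumptions for illustration, not the book's code. Under that assumption the centered version equals NumPy's population covariance, while the standardized version equals the correlation matrix, whose diagonal is all ones, so the per-feature variances are discarded.

```python
import numpy as np

def calc_cov(X, Y=None):
    """Covariance between the columns of X and Y (rows are samples).

    Assumption: normalized by the sample count m (population covariance).
    """
    m = X.shape[0]
    # Center only: subtract the column means, do not divide by the std
    X = X - np.mean(X, axis=0)
    Y = X if Y is None else Y - np.mean(Y, axis=0)
    return X.T @ Y / m

def calc_cov_standardized(X):
    """Hypothetical stand-in for the original version: standardizes first."""
    m = X.shape[0]
    Z = (X - np.mean(X, axis=0)) / np.std(X, axis=0)
    return Z.T @ Z / m

rng = np.random.default_rng(0)
# Three features with very different spreads
X = rng.normal(size=(200, 3)) * np.array([1.0, 5.0, 0.2])

# Centering reproduces the population covariance matrix
print(np.allclose(calc_cov(X), np.cov(X, rowvar=False, bias=True)))  # True
# Standardizing reproduces the correlation matrix instead
print(np.allclose(calc_cov_standardized(X), np.corrcoef(X, rowvar=False)))  # True
```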

With this change, the accuracy reaches 1.0, the same as sklearn.
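
Why this matters for LDA can be sketched with a two-class Fisher discriminant. The sketch assumes, as is typical but not confirmed here for Listing 5-1, that the within-class scatter is the sum of per-class covariance matrices, S_w = Σ_0 + Σ_1, and that the projection is w = S_w^{-1}(μ_0 - μ_1). If each class is standardized by its own standard deviations, S_w no longer records how noisy each feature is, so a high-variance, weakly informative feature is not down-weighted. The synthetic data and all function names below are hypothetical; this illustrates the mechanism and does not reproduce the book's 0.85 figure.

```python
import numpy as np

def cov_centered(X):
    # Population covariance: center the columns, then divide by m
    Xc = X - X.mean(axis=0)
    return Xc.T @ Xc / X.shape[0]

def cov_standardized(X):
    # Original-style variant: also divide by each column's std (yields correlations)
    Z = (X - X.mean(axis=0)) / X.std(axis=0)
    return Z.T @ Z / X.shape[0]

def fisher_lda(X, y, cov_fn):
    """Two-class Fisher LDA: w = pinv(Sw) @ (mu0 - mu1), Sw = cov(X0) + cov(X1)."""
    X0, X1 = X[y == 0], X[y == 1]
    mu0, mu1 = X0.mean(axis=0), X1.mean(axis=0)
    Sw = cov_fn(X0) + cov_fn(X1)
    w = np.linalg.pinv(Sw) @ (mu0 - mu1)
    # Threshold at the midpoint of the two projected class means
    threshold = w @ (mu0 + mu1) / 2
    return w, threshold

def accuracy(X, y, w, threshold):
    # w points toward mu0, so projections below the threshold are predicted as class 1
    y_pred = (X @ w < threshold).astype(int)
    return np.mean(y_pred == y)

rng = np.random.default_rng(0)
n = 2000
scale = np.array([1.0, 10.0])  # feature 1 is ten times noisier than feature 0
X0 = rng.normal(size=(n, 2)) * scale
X1 = rng.normal(size=(n, 2)) * scale + np.array([2.0, 2.0])
X = np.vstack([X0, X1])
y = np.concatenate([np.zeros(n, dtype=int), np.ones(n, dtype=int)])

w_c, t_c = fisher_lda(X, y, cov_centered)
w_s, t_s = fisher_lda(X, y, cov_standardized)
acc_c = accuracy(X, y, w_c, t_c)
acc_s = accuracy(X, y, w_s, t_s)
print(f"centered: {acc_c:.3f}, standardized: {acc_s:.3f}")
```

On data like this the centered version should score clearly higher, because pinv(Sw) shrinks the weight on the noisy feature; the exact numbers depend on the random draw.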
