import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection, LineCollection
# ============================================================
# OLS 示意图:采用“手工投影 + 2D 绘图”的方式
# 说明:
# 1. 不使用 matplotlib 的 mplot3d
# 2. 这样更容易画出类似教材中的干净插图风格
# 3. 代码中的参数(视角、平面大小、点的位置)都可以继续微调
# ============================================================
# ------------------------------------------------------------
# 1. 全局设置
# ------------------------------------------------------------
plt.rcParams["font.family"] = "serif"
plt.rcParams["mathtext.fontset"] = "cm"
# ------------------------------------------------------------
# 2. 定义“屏幕投影基向量”
# 解释:
# - e1 决定 X1 轴在屏幕上的方向
# - e2 决定 X2 轴在屏幕上的方向
# - ey 决定 Y 轴在屏幕上的方向
# 这三个向量不是三维空间里的标准基,而是“投影到屏幕后的方向”
# ------------------------------------------------------------
e1 = np.array([1.45, 0.28]) # X1:向右下方延伸
e2 = np.array([-1.00, 0.62]) # X2:向左上方延伸
ey = np.array([0.00, 1.22]) # Y :基本竖直向上
# ------------------------------------------------------------
# 3. 定义投影函数
# 输入:
# x1, x2, y 可以是标量,也可以是 numpy 数组
# 输出:
# 屏幕上的二维坐标 (u, v)
# ------------------------------------------------------------
def proj(x1, x2, y):
u = e1[0] * x1 + e2[0] * x2 + ey[0] * y
v = e1[1] * x1 + e2[1] * x2 + ey[1] * y
return np.column_stack([u, v])
# ------------------------------------------------------------
# 4. 定义回归平面
# y_hat = b0 + b1*x1 + b2*x2
# 这里的系数可以调,以改变平面倾斜程度
# ------------------------------------------------------------
b0 = 0.40
b1 = 0.23
b2 = 0.36
def yhat(x1, x2):
return b0 + b1 * x1 + b2 * x2
# ------------------------------------------------------------
# 5. 构造平面网格
# 为了让网格看起来均匀、像教材插图,
# 这里直接构造规则网格,然后逐个小格填色
# ------------------------------------------------------------
x1_min, x1_max = 0.0, 8.0
x2_min, x2_max = 0.0, 8.0
nx, ny = 18, 18
x1g = np.linspace(x1_min, x1_max, nx + 1)
x2g = np.linspace(x2_min, x2_max, ny + 1)
# ------------------------------------------------------------
# 6. 构造平面小方块(patches)
# 每个小方块:
# - 在 3D 中有 4 个角点
# - 经过投影后,在 2D 中变成一个四边形
# - 用小方块中心处的 y_hat 决定颜色
# ------------------------------------------------------------
patches = []
face_values = []
for i in range(nx):
for j in range(ny):
xa, xb = x1g[i], x1g[i + 1]
ya, yb = x2g[j], x2g[j + 1]
# 四个角点在平面上的高度
z11 = yhat(xa, ya)
z21 = yhat(xb, ya)
z22 = yhat(xb, yb)
z12 = yhat(xa, yb)
# 投影到二维平面
poly2d = proj(
np.array([xa, xb, xb, xa]),
np.array([ya, ya, yb, yb]),
np.array([z11, z21, z22, z12])
)
patches.append(Polygon(poly2d, closed=True))
# 用中心点的高度作为该小格的颜色值
xc = 0.5 * (xa + xb)
yc = 0.5 * (ya + yb)
face_values.append(yhat(xc, yc))
face_values = np.array(face_values)
# 归一化颜色值
zmin = yhat(x1_min, x2_min)
zmax = yhat(x1_max, x2_max)
face_values_norm = (face_values - zmin) / (zmax - zmin)
# ------------------------------------------------------------
# 7. 构造网格线
# 为了得到类似书中的黑色网格,需要把横纵网格线单独画出来
# ------------------------------------------------------------
grid_segments = []
# x1 方向网格线:固定 x2,改变 x1
for y0 in x2g:
x_line = np.linspace(x1_min, x1_max, 200)
y_line = np.full_like(x_line, y0)
z_line = yhat(x_line, y_line)
p = proj(x_line, y_line, z_line)
segs = np.stack([p[:-1], p[1:]], axis=1)
grid_segments.extend(segs)
# x2 方向网格线:固定 x1,改变 x2
for x0 in x1g:
y_line = np.linspace(x2_min, x2_max, 200)
x_line = np.full_like(y_line, x0)
z_line = yhat(x_line, y_line)
p = proj(x_line, y_line, z_line)
segs = np.stack([p[:-1], p[1:]], axis=1)
grid_segments.extend(segs)
# ------------------------------------------------------------
# 8. 手工指定观测点
# 这里不再随机生成,而是手工布点。
# 原因:
# - 教材中的图往往是“示意性”的
# - 手工布点更容易得到美观结果
# ------------------------------------------------------------
pts = np.array([
[0.9, 1.2, -0.90],
[1.4, 6.8, 0.55],
[2.7, 6.0, 0.35],
[4.0, 7.4, 0.70],
[5.7, 6.8, 0.85],
[7.0, 4.8, 0.45],
[7.2, 2.3, -0.75],
[6.0, 1.5, -0.55],
[3.2, 2.2, -0.45],
[2.2, 7.8, 0.90],
])
x1p = pts[:, 0]
x2p = pts[:, 1]
eps = pts[:, 2]
y_fit = yhat(x1p, x2p)
y_obs = y_fit + eps
# 平面上的点(残差线起点)
p_fit = proj(x1p, x2p, y_fit)
# 观测点(残差线终点)
p_obs = proj(x1p, x2p, y_obs)
# ------------------------------------------------------------
# 9. 坐标轴
# 手动画三根轴,避免默认三维坐标轴的杂乱感
# ------------------------------------------------------------
origin = proj(0.0, 0.0, 0.0)[0]
x1_end = proj(9.2, 0.0, 0.0)[0]
x2_end = proj(0.0, 9.0, 0.0)[0]
y_end = proj(0.0, 0.0, 7.2)[0]
# ------------------------------------------------------------
# 10. 开始绘图
# ------------------------------------------------------------
fig, ax = plt.subplots(figsize=(8.4, 6.6), facecolor="#d9d9d9")
ax.set_facecolor("#d9d9d9")
# 平面小方块着色
pc = PatchCollection(
patches,
cmap="winter",
edgecolor="none",
linewidth=0.0,
zorder=1
)
pc.set_array(face_values_norm)
ax.add_collection(pc)
# 网格线
lc = LineCollection(
grid_segments,
colors=[(0, 0, 0, 0.45)],
linewidths=0.9,
zorder=2
)
ax.add_collection(lc)
# 三根坐标轴
ax.annotate(
"",
xy=x1_end,
xytext=origin,
arrowprops=dict(arrowstyle="-|>", lw=1.6, color="0.20"),
zorder=3
)
ax.annotate(
"",
xy=x2_end,
xytext=origin,
arrowprops=dict(arrowstyle="-|>", lw=1.6, color="0.20"),
zorder=3
)
ax.annotate(
"",
xy=y_end,
xytext=origin,
arrowprops=dict(arrowstyle="-|>", lw=1.6, color="0.20"),
zorder=3
)
# 轴标签
ax.text(
x1_end[0] + 0.18, x1_end[1] - 0.02,
r"$X_1$", fontsize=28, ha="left", va="center"
)
ax.text(
x2_end[0] - 0.08, x2_end[1] + 0.14,
r"$X_2$", fontsize=28, ha="center", va="center"
)
ax.text(
y_end[0] + 0.06, y_end[1] + 0.10,
r"$Y$", fontsize=28, ha="center", va="bottom"
)
# 残差线
for a, b in zip(p_fit, p_obs):
ax.plot(
[a[0], b[0]],
[a[1], b[1]],
color="0.25",
lw=1.4,
zorder=4
)
# 红色观测点
ax.scatter(
p_obs[:, 0], p_obs[:, 1],
s=44,
c="red",
edgecolors="none",
zorder=5
)
# ------------------------------------------------------------
# 11. 自动设置显示范围
# ------------------------------------------------------------
all_xy = np.vstack([
np.array([origin, x1_end, x2_end, y_end]),
p_obs,
p_fit
])
xmin, ymin = all_xy.min(axis=0)
xmax, ymax = all_xy.max(axis=0)
pad_x = 0.7
pad_y = 0.7
ax.set_xlim(xmin - pad_x, xmax + pad_x)
ax.set_ylim(ymin - pad_y, ymax + pad_y)
# 保证比例一致,否则图形会被拉伸
ax.set_aspect("equal")
# 去掉边框与刻度
ax.axis("off")
plt.tight_layout(pad=0.4)
plt.show()
# 如需保存图片,取消下面一行注释即可
# plt.savefig("ols_plane_james_style.png", dpi=300, bbox_inches="tight", facecolor=fig.get_facecolor())