最近开始看训练相关的东西(主要是了解,不持续学习是不行的),先不搞复杂框架,先把最核心的一条线搞明白。
我现在的理解很朴素:
大模型就是一个很复杂的函数,参数非常多。
训练这件事,本质上就是不断调这些参数(通常写成 w),让输出越来越接近目标答案。
背景
我之前主要做存储和后端,习惯先把链路走通,再补理论。
所以训练这块我也先用白话理解主流程,不急着一上来就啃完所有数学推导。
我现在理解的训练主流程
可以先记成 4 步:
forward:先用当前参数跑一遍,得到预测结果。
loss:拿预测和真实答案做对比,算误差。
backward:根据误差反向算梯度,知道参数该往哪边调。
update:按梯度更新参数,然后进入下一轮。
一句话总结就是:
先猜 -> 看差多少 -> 反向修正 -> 再猜。
我用的最小例子(先把流程跑通)
这里的小脚本不是公共教程文件,而是我自己写的最小训练示例,用三次多项式去拟合 sin(x),目的只有一个:把训练主流程跑明白。
它大致就这几步:
X = [1, x, x^2, x^3]
pred = X @ w
loss = mse(pred, y)
grad = X.T @ (pred - y)
w = w - lr * grad
然后加上验证集、学习率调度、early stopping、checkpoint,这样就有了一个完整但不复杂的训练闭环。
附:完整示例脚本(可直接运行)
下面这份就是我现在用来理解 forward / loss / backward / update 的完整最小脚本。
# -*- coding: utf-8 -*-
import math
from pathlib import Path
import numpy as np
def build_features(x_value):
    """Return the cubic-polynomial design matrix.

    Columns are [1, x, x^2, x^3], one row per input sample; x_value may
    be a scalar or a 1-D array.
    """
    columns = [np.ones_like(x_value)]
    for degree in (1, 2, 3):
        columns.append(x_value ** degree)
    return np.column_stack(columns)
def predict(feature_matrix, theta):
    """Linear-model forward pass: y_hat = feature_matrix @ theta."""
    return np.matmul(feature_matrix, theta)
def metrics(y_true, y_pred):
    """Compute a bundle of regression metrics.

    Returns a dict with keys "sse", "mse", "rmse", "mae", "r2".
    R^2 is NaN when y_true has zero variance (the baseline is degenerate).
    """
    residual = y_pred - y_true
    sum_sq = (residual ** 2).sum()
    mean_sq = sum_sq / residual.size
    total_var = ((y_true - y_true.mean()) ** 2).sum()
    if total_var > 0:
        r_squared = 1.0 - sum_sq / total_var
    else:
        r_squared = float("nan")
    return {
        "sse": sum_sq,
        "mse": mean_sq,
        "rmse": np.sqrt(mean_sq),
        "mae": np.abs(residual).mean(),
        "r2": r_squared,
    }
def save_checkpoint(path, epoch, theta, best_val_mse, lr):
    """Persist a training snapshot as an .npz archive.

    A checkpoint is a "save point" in training.  Typical uses:
    1) keep the best model found so far (usually judged by a validation
       metric);
    2) resume an interrupted run from this point instead of restarting.
    """
    payload = {
        "epoch": np.array([epoch], dtype=np.int64),
        "theta": theta,
        "best_val_mse": np.array([best_val_mse], dtype=np.float64),
        "lr": np.array([lr], dtype=np.float64),
    }
    np.savez(path, **payload)
def load_checkpoint(path):
    """Load a snapshot written by save_checkpoint.

    Returns a (epoch, theta, best_val_mse, lr) tuple; theta is copied so
    the caller owns a writable array.
    """
    archive = np.load(path)
    return (
        int(archive["epoch"][0]),
        archive["theta"].copy(),
        float(archive["best_val_mse"][0]),
        float(archive["lr"][0]),
    )
# 1) Training configuration (the usual knobs of a standard training pipeline).
seed = 42                    # RNG seed: makes split + init reproducible
num_points = 2000            # dataset size
max_epochs = 20000           # hard cap on epochs
base_lr = 1e-6               # initial learning rate
min_lr = 1e-9                # floor for the plateau scheduler
lr_decay = 0.5               # multiplicative lr decay on plateau
lr_patience = 500            # epochs w/o val improvement before decaying lr
early_stop_patience = 2000   # epochs w/o val improvement before stopping
min_delta = 1e-12            # minimum val-MSE drop that counts as improvement
log_every = 100              # logging interval (epochs)
train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15            # informational only: n_test is derived as the remainder
checkpoint_dir = Path(__file__).resolve().parent / "checkpoints"
checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint_path = checkpoint_dir / "t1_best.npz"
rng = np.random.default_rng(seed)
# 2) Dataset construction: cubic features of x, target y = sin(x).
x_all = np.linspace(-math.pi, math.pi, num_points)
y_all = np.sin(x_all)
X_all = build_features(x_all)
# 3) train/val/test split via one shuffled index permutation.
indices = rng.permutation(num_points)
n_train = int(num_points * train_ratio)
n_val = int(num_points * val_ratio)
n_test = num_points - n_train - n_val  # remainder absorbs int() rounding
train_idx = indices[:n_train]
val_idx = indices[n_train:n_train + n_val]
test_idx = indices[n_train + n_val:]
X_train, y_train = X_all[train_idx], y_all[train_idx]
X_val, y_val = X_all[val_idx], y_all[val_idx]
X_test, y_test = X_all[test_idx], y_all[test_idx]
# 4) Parameter initialization: 4 polynomial coefficients drawn from N(0, 1).
theta = rng.standard_normal(4)
lr = base_lr
best_theta = theta.copy()
best_val_mse = float("inf")  # best validation MSE observed so far
best_epoch = 0
bad_epochs_for_stop = 0      # plateau counter feeding early stopping
bad_epochs_for_lr = 0        # plateau counter feeding the lr scheduler
print("=== 标准训练流程: 三次多项式拟合 sin(x) ===")
print(
    f"config: seed={seed}, points={num_points}, split(train/val/test)="
    f"{n_train}/{n_val}/{n_test}, max_epochs={max_epochs}, lr={base_lr}"
)
print(
    f"early_stopping: patience={early_stop_patience}, min_delta={min_delta}; "
    f"lr_scheduler(on plateau): patience={lr_patience}, decay={lr_decay}, min_lr={min_lr}"
)
print(f"checkpoint: {checkpoint_path}")
print("log fields: epoch, lr, train_mse, val_mse, train_r2, val_r2, grad_norm, best_val_mse")
# 5) Training loop (epoch-level; full-batch gradient descent, no minibatching).
for epoch in range(1, max_epochs + 1):
    # forward(train): predictions with the current parameters.
    train_pred = predict(X_train, theta)
    train_err = train_pred - y_train
    train_stats = metrics(y_train, train_pred)
    # backward(train): analytic MSE gradient; dividing by the sample count
    # keeps the effective step size independent of the training-set size.
    grad = (2.0 / y_train.size) * (X_train.T @ train_err)
    grad_norm = np.linalg.norm(grad)
    # step: plain gradient-descent parameter update.
    theta = theta - lr * grad
    # evaluate(val): model selection and early stopping watch the val split.
    val_pred = predict(X_val, theta)
    val_stats = metrics(y_val, val_pred)
    # An epoch only counts as an improvement if val MSE drops by more
    # than min_delta below the best value seen so far.
    improved = val_stats["mse"] < (best_val_mse - min_delta)
    if improved:
        best_val_mse = val_stats["mse"]
        best_theta = theta.copy()
        best_epoch = epoch
        bad_epochs_for_stop = 0
        bad_epochs_for_lr = 0
        # Checkpoint only on improvement, so the file always holds the best model.
        save_checkpoint(checkpoint_path, epoch, best_theta, best_val_mse, lr)
    else:
        bad_epochs_for_stop += 1
        bad_epochs_for_lr += 1
    # lr on plateau: decay the learning rate (bounded below by min_lr)
    # after lr_patience epochs without validation improvement.
    if bad_epochs_for_lr >= lr_patience:
        new_lr = max(lr * lr_decay, min_lr)
        if new_lr < lr:
            print(
                f"[epoch {epoch}] val plateau {bad_epochs_for_lr} epochs -> "
                f"lr {lr:.3e} -> {new_lr:.3e}"
            )
        lr = new_lr
        bad_epochs_for_lr = 0
    # logging: first epoch plus every log_every-th epoch.
    if epoch == 1 or epoch % log_every == 0:
        print(
            f"[{epoch:5d}/{max_epochs}] lr={lr:.3e} "
            f"train_mse={train_stats['mse']:.6f} val_mse={val_stats['mse']:.6f} "
            f"train_r2={train_stats['r2']:.6f} val_r2={val_stats['r2']:.6f} "
            f"grad_norm={grad_norm:.3e} best_val_mse={best_val_mse:.6f}"
        )
    # early stopping: abandon training after early_stop_patience epochs
    # without any validation improvement.
    if bad_epochs_for_stop >= early_stop_patience:
        print(
            f"[epoch {epoch}] early stopping triggered: "
            f"no val improvement in {early_stop_patience} epochs."
        )
        break
# 6) Reload the best checkpoint (if one was written) and evaluate that model
#    uniformly on the train/val/test splits.
if checkpoint_path.exists():
    ckpt_epoch, ckpt_theta, ckpt_best_val_mse, ckpt_lr = load_checkpoint(checkpoint_path)
    best_theta = ckpt_theta
    best_epoch = ckpt_epoch
    best_val_mse = ckpt_best_val_mse
else:
    # No checkpoint was ever saved (val never improved); keep in-memory state.
    ckpt_lr = lr
train_final = metrics(y_train, predict(X_train, best_theta))
val_final = metrics(y_val, predict(X_val, best_theta))
test_final = metrics(y_test, predict(X_test, best_theta))
a, b, c, d = best_theta  # coefficients of y = a + b x + c x^2 + d x^3
print("\n=== 训练结束(使用最佳 checkpoint)===")
print(f"best_epoch={best_epoch}, best_val_mse={best_val_mse:.6f}, ckpt_lr={ckpt_lr:.3e}")
print(f"model: y = {a:.12f} + {b:.12f} x + {c:.12f} x^2 + {d:.12f} x^3")
print(
    f"train: mse={train_final['mse']:.6f}, rmse={train_final['rmse']:.6f}, "
    f"mae={train_final['mae']:.6f}, r2={train_final['r2']:.6f}"
)
print(
    f"val : mse={val_final['mse']:.6f}, rmse={val_final['rmse']:.6f}, "
    f"mae={val_final['mae']:.6f}, r2={val_final['r2']:.6f}"
)
print(
    f"test : mse={test_final['mse']:.6f}, rmse={test_final['rmse']:.6f}, "
    f"mae={test_final['mae']:.6f}, r2={test_final['r2']:.6f}"
)
print(f"checkpoint saved at: {checkpoint_path}")
# Sanity check: compare prediction vs ground truth at a few landmark x values.
sample_x = np.array([-math.pi, -math.pi / 2, 0.0, math.pi / 2, math.pi])
sample_feature = build_features(sample_x)
sample_pred = predict(sample_feature, best_theta)
sample_true = np.sin(sample_x)
print("\n关键点对比 (x, pred, true, abs_err):")
for x_value, pred_value, true_value in zip(sample_x, sample_pred, sample_true):
    abs_err = abs(pred_value - true_value)
    print(f" x={x_value:+.6f}, pred={pred_value:+.6f}, true={true_value:+.6f}, abs_err={abs_err:.6f}")