Administrator
发布于 2026-02-22 / 23 阅读
0
0

训练初了解:把大模型看成一个复杂函数(通俗版)

最近开始看训练相关的东西(主要是了解,不持续学习是不行的),先不搞复杂框架,先把最核心的一条线搞明白。

我现在的理解很朴素:

大模型就是一个很复杂的函数,参数非常多。

训练这件事,本质上就是不断调这些参数(通常写成 w),让输出越来越接近目标答案。

背景

我之前主要做存储和后端,习惯先把链路走通,再补理论。

所以训练这块我也先用白话理解主流程,不急着一上来就啃完所有数学推导。

我现在理解的训练主流程

可以先记成 4 步:

  1. forward:先用当前参数跑一遍,得到预测结果。
  2. loss:拿预测和真实答案做对比,算误差。
  3. backward:根据误差反向算梯度,知道参数该往哪边调。
  4. update:按梯度更新参数,然后进入下一轮。

一句话总结就是:

先猜 -> 看差多少 -> 反向修正 -> 再猜。

我用的最小例子(先把流程跑通)

这里的小脚本不是公共教程文件,而是我自己写的最小训练示例,用三次多项式去拟合 sin(x),目的只有一个:把训练主流程跑明白。

它大致就这几步:

X = [1, x, x^2, x^3]
pred = X @ w
loss = mse(pred, y)
grad = X.T @ (pred - y)
w = w - lr * grad

然后加上验证集、学习率调度、early stopping、checkpoint,这样就有了一个完整但不复杂的训练闭环。

附:完整示例脚本(可直接运行)

下面这份就是我现在用来理解 forward / loss / backward / update 的完整最小脚本。

# -*- coding: utf-8 -*-
import math
from pathlib import Path

import numpy as np


def build_features(x_value):
    """构造三次多项式特征矩阵 [1, x, x^2, x^3]。"""
    return np.column_stack((np.ones_like(x_value), x_value, x_value ** 2, x_value ** 3))


def predict(feature_matrix, theta):
    """线性模型前向: y_hat = X @ theta。"""
    return feature_matrix @ theta


def metrics(y_true, y_pred):
    """回归指标集合。"""
    err = y_pred - y_true
    sse = np.square(err).sum()
    mse = sse / err.size
    rmse = np.sqrt(mse)
    mae = np.abs(err).mean()
    ss_tot = np.square(y_true - y_true.mean()).sum()
    r2 = 1.0 - sse / ss_tot if ss_tot > 0 else float("nan")
    return {
        "sse": sse,
        "mse": mse,
        "rmse": rmse,
        "mae": mae,
        "r2": r2,
    }


def save_checkpoint(path, epoch, theta, best_val_mse, lr):
    """
    checkpoint = 训练过程中的“存档点”。
    常见用途:
    1) 保留当前最优模型(通常按 val 指标)。
    2) 训练中断后从该点继续,而不是从头开始。
    """
    np.savez(
        path,
        epoch=np.array([epoch], dtype=np.int64),
        theta=theta,
        best_val_mse=np.array([best_val_mse], dtype=np.float64),
        lr=np.array([lr], dtype=np.float64),
    )


def load_checkpoint(path):
    data = np.load(path)
    epoch = int(data["epoch"][0])
    theta = data["theta"].copy()
    best_val_mse = float(data["best_val_mse"][0])
    lr = float(data["lr"][0])
    return epoch, theta, best_val_mse, lr


# 1) 训练配置(标准训练流程常见配置)
seed = 42
num_points = 2000
max_epochs = 20000
base_lr = 1e-6
min_lr = 1e-9
lr_decay = 0.5
lr_patience = 500
early_stop_patience = 2000
min_delta = 1e-12
log_every = 100

train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15

checkpoint_dir = Path(__file__).resolve().parent / "checkpoints"
checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint_path = checkpoint_dir / "t1_best.npz"

rng = np.random.default_rng(seed)

# 2) 数据集构造
x_all = np.linspace(-math.pi, math.pi, num_points)
y_all = np.sin(x_all)
X_all = build_features(x_all)

# 3) train/val/test 划分
indices = rng.permutation(num_points)
n_train = int(num_points * train_ratio)
n_val = int(num_points * val_ratio)
n_test = num_points - n_train - n_val

train_idx = indices[:n_train]
val_idx = indices[n_train:n_train + n_val]
test_idx = indices[n_train + n_val:]

X_train, y_train = X_all[train_idx], y_all[train_idx]
X_val, y_val = X_all[val_idx], y_all[val_idx]
X_test, y_test = X_all[test_idx], y_all[test_idx]

# 4) 参数初始化
theta = rng.standard_normal(4)
lr = base_lr

best_theta = theta.copy()
best_val_mse = float("inf")
best_epoch = 0

bad_epochs_for_stop = 0
bad_epochs_for_lr = 0

print("=== 标准训练流程: 三次多项式拟合 sin(x) ===")
print(
    f"config: seed={seed}, points={num_points}, split(train/val/test)="
    f"{n_train}/{n_val}/{n_test}, max_epochs={max_epochs}, lr={base_lr}"
)
print(
    f"early_stopping: patience={early_stop_patience}, min_delta={min_delta}; "
    f"lr_scheduler(on plateau): patience={lr_patience}, decay={lr_decay}, min_lr={min_lr}"
)
print(f"checkpoint: {checkpoint_path}")
print("log fields: epoch, lr, train_mse, val_mse, train_r2, val_r2, grad_norm, best_val_mse")

# 5) 训练循环(epoch 级)
for epoch in range(1, max_epochs + 1):
    # forward(train)
    train_pred = predict(X_train, theta)
    train_err = train_pred - y_train
    train_stats = metrics(y_train, train_pred)

    # backward(train): MSE 梯度,除以样本数后更稳定
    grad = (2.0 / y_train.size) * (X_train.T @ train_err)
    grad_norm = np.linalg.norm(grad)

    # step
    theta = theta - lr * grad

    # evaluate(val)
    val_pred = predict(X_val, theta)
    val_stats = metrics(y_val, val_pred)

    improved = val_stats["mse"] < (best_val_mse - min_delta)
    if improved:
        best_val_mse = val_stats["mse"]
        best_theta = theta.copy()
        best_epoch = epoch
        bad_epochs_for_stop = 0
        bad_epochs_for_lr = 0
        save_checkpoint(checkpoint_path, epoch, best_theta, best_val_mse, lr)
    else:
        bad_epochs_for_stop += 1
        bad_epochs_for_lr += 1

    # lr on plateau
    if bad_epochs_for_lr >= lr_patience:
        new_lr = max(lr * lr_decay, min_lr)
        if new_lr < lr:
            print(
                f"[epoch {epoch}] val plateau {bad_epochs_for_lr} epochs -> "
                f"lr {lr:.3e} -> {new_lr:.3e}"
            )
            lr = new_lr
        bad_epochs_for_lr = 0

    # logging
    if epoch == 1 or epoch % log_every == 0:
        print(
            f"[{epoch:5d}/{max_epochs}] lr={lr:.3e} "
            f"train_mse={train_stats['mse']:.6f} val_mse={val_stats['mse']:.6f} "
            f"train_r2={train_stats['r2']:.6f} val_r2={val_stats['r2']:.6f} "
            f"grad_norm={grad_norm:.3e} best_val_mse={best_val_mse:.6f}"
        )

    # early stopping
    if bad_epochs_for_stop >= early_stop_patience:
        print(
            f"[epoch {epoch}] early stopping triggered: "
            f"no val improvement in {early_stop_patience} epochs."
        )
        break

# 6) 载入最佳 checkpoint(若存在),并在 train/val/test 统一评估
if checkpoint_path.exists():
    ckpt_epoch, ckpt_theta, ckpt_best_val_mse, ckpt_lr = load_checkpoint(checkpoint_path)
    best_theta = ckpt_theta
    best_epoch = ckpt_epoch
    best_val_mse = ckpt_best_val_mse
else:
    ckpt_lr = lr

train_final = metrics(y_train, predict(X_train, best_theta))
val_final = metrics(y_val, predict(X_val, best_theta))
test_final = metrics(y_test, predict(X_test, best_theta))

a, b, c, d = best_theta

print("\n=== 训练结束(使用最佳 checkpoint)===")
print(f"best_epoch={best_epoch}, best_val_mse={best_val_mse:.6f}, ckpt_lr={ckpt_lr:.3e}")
print(f"model: y = {a:.12f} + {b:.12f} x + {c:.12f} x^2 + {d:.12f} x^3")
print(
    f"train: mse={train_final['mse']:.6f}, rmse={train_final['rmse']:.6f}, "
    f"mae={train_final['mae']:.6f}, r2={train_final['r2']:.6f}"
)
print(
    f"val  : mse={val_final['mse']:.6f}, rmse={val_final['rmse']:.6f}, "
    f"mae={val_final['mae']:.6f}, r2={val_final['r2']:.6f}"
)
print(
    f"test : mse={test_final['mse']:.6f}, rmse={test_final['rmse']:.6f}, "
    f"mae={test_final['mae']:.6f}, r2={test_final['r2']:.6f}"
)
print(f"checkpoint saved at: {checkpoint_path}")

sample_x = np.array([-math.pi, -math.pi / 2, 0.0, math.pi / 2, math.pi])
sample_feature = build_features(sample_x)
sample_pred = predict(sample_feature, best_theta)
sample_true = np.sin(sample_x)

print("\n关键点对比 (x, pred, true, abs_err):")
for x_value, pred_value, true_value in zip(sample_x, sample_pred, sample_true):
    abs_err = abs(pred_value - true_value)
    print(f"  x={x_value:+.6f}, pred={pred_value:+.6f}, true={true_value:+.6f}, abs_err={abs_err:.6f}")


评论