
How to Apply Constraints or Regularization Terms to Learnable Parameters in PyTorch

Depending on your needs, you may sometimes want to apply a custom constraint or regularization term to a model's learnable parameters in PyTorch. This post walks through how to do that. First, look at the concrete form of adding a regularization term to the loss function; the L2 regularization term is:

$$\tilde{E}(w) = E(w) + \frac{\lambda}{2} \lVert w \rVert_2^2$$

In the formula above, E(w) is the training error as a function of the learnable parameters w, and the second term on the right-hand side is the L2 regularization term with hyperparameter λ. In PyTorch, L2 regularization is built into the optimizers by default, where the weight_decay argument is the hyperparameter of the L2 term. For example:

optimizer = optim.SGD(net.parameters(), lr=0.01, weight_decay=0.01)

How, then, can you define a regularization function of your own? A complete example follows:

import torch

torch.manual_seed(1)

N, D_in, H, D_out = 10, 5, 5, 1
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
criterion = torch.nn.MSELoss()

lr = 1e-4
weight_decay = 0  # for torch.optim.SGD; disabled so the built-in L2 term does not double-count
lmbd = 0.9        # hyperparameter for the custom L2 regularization

optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)

for t in range(100):
    y_pred = model(x)
    # Compute the data-fitting loss.
    loss = criterion(y_pred, y)
    optimizer.zero_grad()

    # Accumulate the custom L2 penalty over all learnable parameters.
    # The two forms below are equivalent: 0.5 * sum(w**2) == 0.5 * ||w||_2**2.
    reg_loss = None
    for param in model.parameters():
        if reg_loss is None:
            reg_loss = 0.5 * torch.sum(param**2)
        else:
            reg_loss = reg_loss + 0.5 * param.norm(2)**2

    loss += lmbd * reg_loss
    loss.backward()
    optimizer.step()

for name, param in model.named_parameters():
    print(name, param)

In the code above, the part shown below is the piece you can customize to implement your own regularization term or constraint (see the L1 sketch right after it):

reg_loss = None
for param in model.parameters():
    if reg_loss is None:
        reg_loss = 0.5 * torch.sum(param**2)
    else:
        reg_loss = reg_loss + 0.5 * param.norm(2)**2
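For instance, swapping the squared L2 penalty for an L1 penalty only requires changing the accumulated term. Below is a minimal sketch, assuming the same model, lmbd, and training loop as in the example above:

reg_loss = None
for param in model.parameters():
    if reg_loss is None:
        reg_loss = torch.sum(torch.abs(param))  # L1 norm of the first parameter tensor
    else:
        reg_loss = reg_loss + param.norm(1)     # add the L1 norm of each remaining tensor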

 
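As for hard constraints on the parameters themselves (as opposed to soft penalty terms added to the loss), one common approach is to project the parameters back into the feasible set after each optimizer.step() call. Below is a minimal, self-contained sketch assuming a simple box constraint; the bound [-0.5, 0.5] is purely illustrative and not from the original post:

import torch

torch.manual_seed(1)
model = torch.nn.Linear(5, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
x, y = torch.randn(10, 5), torch.randn(10, 1)

for t in range(100):
    loss = torch.nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Project the weights back into the feasible set after each update.
    with torch.no_grad():
        for param in model.parameters():
            param.clamp_(-0.5, 0.5)  # hypothetical bound; choose to suit your constraint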


