遙想當年,AlphaGo的Master版本,在完勝柯潔九段之后不久,就被后輩AlphaGo Zero (簡稱狗零) 擊潰了。 從一只完全不懂圍棋的AI,到打敗Master,狗零只用了21天。 而且,它不需要用人類知識來喂養(yǎng),成為頂尖棋手全靠自學。 如果能培育這樣一只AI,即便自己不會下棋,也可以很驕傲吧。 于是,來自巴黎的少年Dylan Djian (簡稱小笛) ,就照著狗零的論文去實現(xiàn)了一下。 他給自己的AI棋手起名SuperGo,也提供了代碼 (傳送門見文底) 。 除此之外,還有教程—— 一個身子兩個頭智能體分成三個部分: 一是特征提取器 (Feature Extractor) ,二是策略網(wǎng)絡 (Policy Network) ,三是價值網(wǎng)絡 (Value Network) 。 于是,狗零也被親切地稱為“雙頭怪”。特征提取器是身子,其他兩個網(wǎng)絡是腦子。 特征提取器 特征提取模型,是個殘差網(wǎng)絡 (ResNet) ,就是給普通CNN加上了跳層連接 (Skip Connection) , 讓梯度的傳播更加通暢。 跳躍的樣子,寫成代碼就是: 1class BasicBlock(nn.Module): 2 ''' 3 Basic residual block with 2 convolutions and a skip connection 4 before the last ReLU activation. 5 ''' 6 7 def __init__(self, inplanes, planes, stride=1, downsample=None): 8 super(BasicBlock, self).__init__() 9 10 self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, 11 stride=stride, padding=1, bias=False) 12 self.bn1 = nn.BatchNorm2d(planes) 13 14 self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 15 stride=stride, padding=1, bias=False) 16 self.bn2 = nn.BatchNorm2d(planes) 17 18 19 def forward(self, x): 20 residual = x 21 22 out = self.conv1(x) 23 out = F.relu(self.bn1(out)) 24 25 out = self.conv2(out) 26 out = self.bn2(out) 27 28 out += residual 29 out = F.relu(out) 30 31 return out 然后,把它加到特征提取模型里面去: 1class Extractor(nn.Module): 2 def __init__(self, inplanes, outplanes): 3 super(Extractor, self).__init__() 4 self.conv1 = nn.Conv2d(inplanes, outplanes, stride=1, 5 kernel_size=3, padding=1, bias=False) 6 self.bn1 = nn.BatchNorm2d(outplanes) 7 8 for block in range(BLOCKS): 9 setattr(self, 'res{}'.format(block), \ 10 BasicBlock(outplanes, outplanes)) 11 12 13 def forward(self, x): 14 x = F.relu(self.bn1(self.conv1(x))) 15 for block in range(BLOCKS - 1): 16 x = getattr(self, 'res{}'.format(block))(x) 17 18 feature_maps = getattr(self, 'res{}'.format(BLOCKS - 1))(x) 19 return feature_maps 策略網(wǎng)絡 策略網(wǎng)絡就是普通的CNN了,里面有個批量標準化 (Batch Normalization) ,還有一個全連接層,輸出概率分布。 1class PolicyNet(nn.Module): 2 def __init__(self, inplanes, outplanes): 3 super(PolicyNet, self).__init__() 4 self.outplanes = outplanes 5 self.conv = nn.Conv2d(inplanes, 1, kernel_size=1) 6 self.bn = nn.BatchNorm2d(1) 7 self.logsoftmax = nn.LogSoftmax(dim=1) 8 self.fc = nn.Linear(outplanes - 1, outplanes) 9 10 11 def forward(self, x): 12 x = F.relu(self.bn(self.conv(x))) 13 x = x.view(-1, self.outplanes - 1) 14 x = self.fc(x) 15 probas = self.logsoftmax(x).exp() 16 17 return probas 價值網(wǎng)絡 這個網(wǎng)絡稍微復雜一點。除了標配之外,還要再多加一個全連接層。最后,用雙曲正切 (Hyperbolic Tangent) 算出 (-1,1) 之間的數(shù)值,來表示當前狀態(tài)下的贏面多大。 代碼長這樣—— 1class ValueNet(nn.Module): 2 def __init__(self, inplanes, outplanes): 3 super(ValueNet, self).__init__() 4 self.outplanes = outplanes 5 self.conv = nn.Conv2d(inplanes, 1, kernel_size=1) 6 self.bn = nn.BatchNorm2d(1) 7 self.fc1 = nn.Linear(outplanes - 1, 256) 8 self.fc2 = nn.Linear(256, 1) 9 10 11 def forward(self, x): 12 x = F.relu(self.bn(self.conv(x))) 13 x = x.view(-1, self.outplanes - 1) 14 x = F.relu(self.fc1(x)) 15 winning = F.tanh(self.fc2(x)) 16 return winning 未雨綢繆的樹狗零,還有一個很重要的組成部分,就是蒙特卡洛樹搜索 (MCTS) 。 它可以讓AI棋手提前找出,勝率最高的落子點。 在模擬器里,模擬對方的下一手,以及再下一手,給出應對之策,所以提前的遠不止是一步。 節(jié)點 (Node) 樹上的每一個節(jié)點,都代表一種不同的局勢,有不同的統(tǒng)計數(shù)據(jù): 每個節(jié)點被經(jīng)過的次數(shù)n,總動作值w,經(jīng)過這一點的先驗概率p,平均動作值q (q=w/n) ,還有從別處來到這個節(jié)點走的那一步,以及從這個節(jié)點出發(fā)、所有可能的下一步。 1class Node: 2 def __init__(self, parent=None, proba=None, move=None): 3 self.p = proba 4 self.n = 0 5 self.w = 0 6 self.q = 0 7 self.children = [] 8 self.parent = parent 9 self.move = move 部署 (Rollout) 第一步是PUCT (多項式上置信樹) 算法,選擇能讓PUCT函數(shù) (下圖) 的某個變體 (Variant) 最大化,的走法。 寫成代碼的話—— 1def select(nodes, c_puct=C_PUCT): 2 ' Optimized version of the selection based of the PUCT formula ' 3 4 total_count = 0 5 for i in range(nodes.shape[0]): 6 total_count += nodes[i][1] 7 8 action_scores = np.zeros(nodes.shape[0]) 9 for i in range(nodes.shape[0]): 10 action_scores[i] = nodes[i][0] + c_puct * nodes[i][2] * \ 11 (np.sqrt(total_count) / (1 + nodes[i][1])) 12 13 equals = np.where(action_scores == np.max(action_scores))[0] 14 if equals.shape[0] > 0: 15 return np.random.choice(equals) 16 return equals[0] 結束 (Ending) 選擇在不停地進行,直至到達一個葉節(jié)點 (Leaf Node) ,而這個節(jié)點還沒有往下生枝。 1def is_leaf(self): 2 ''' Check whether a node is a leaf or not ''' 3 4 return len(self.children) == 0 到了葉節(jié)點,那里的一個隨機狀態(tài)就會被評估,得出所有“下一步”的概率。 所有被禁的落子點,概率會變成零,然后重新把總概率歸為1。 然后,這個葉節(jié)點就會生出枝節(jié) (都是可以落子的位置,概率不為零的那些) 。代碼如下—— 1def expand(self, probas): 2 self.children = [Node(parent=self, move=idx, proba=probas[idx]) \ 3 for idx in range(probas.shape[0]) if probas[idx] > 0] 更新一下 枝節(jié)生好之后,這個葉節(jié)點和它的媽媽們,身上的統(tǒng)計數(shù)據(jù)都會更新,用的是下面這兩串代碼。 1def update(self, v): 2 ''' Update the node statistics after a rollout ''' 3 4 self.w = self.w + v 5 self.q = self.w / self.n if self.n > 0 else 0 1while current_node.parent: 2 current_node.update(v) 3 current_node = current_node.parent 選擇落子點模擬器搭好了,每個可能的“下一步”,都有了自己的統(tǒng)計數(shù)據(jù)。 按照這些數(shù)據(jù),算法會選擇其中一步,真要落子的地方。 選擇有兩種,一就是選擇被模擬的次數(shù)最多的點。試用于測試和實戰(zhàn)。 另外一種,隨機 (Stochastically) 選擇,把節(jié)點被經(jīng)過的次數(shù)轉換成概率分布,用的是以下代碼—— 1 total = np.sum(action_scores) 2 probas = action_scores / total 3 move = np.random.choice(action_scores.shape[0], p=probas) 后者適用于訓練,讓AlphaGo探索更多可能的選擇。 三位一體的修煉狗零的修煉分為三個過程,是異步的。 一是自對弈 (Self-Play) ,用來生成數(shù)據(jù)。 1def self_play(): 2 while True: 3 new_player, checkpoint = load_player() 4 if new_player: 5 player = new_player 6 7 ## Create the self-play match queue of processes 8 results = create_matches(player, cores=PARALLEL_SELF_PLAY, 9 match_number=SELF_PLAY_MATCH) 10 for _ in range(SELF_PLAY_MATCH): 11 result = results.get() 12 db.insert({ 13 'game': result, 14 'id': game_id 15 }) 16 game_id += 1 二是訓練 (Training) ,拿新鮮生成的數(shù)據(jù),來改進當前的神經(jīng)網(wǎng)絡。 1def train(): 2 criterion = AlphaLoss() 3 dataset = SelfPlayDataset() 4 player, checkpoint = load_player(current_time, loaded_version) 5 optimizer = create_optimizer(player, lr, 6 param=checkpoint['optimizer']) 7 best_player = deepcopy(player) 8 dataloader = DataLoader(dataset, collate_fn=collate_fn, \ 9 batch_size=BATCH_SIZE, shuffle=True) 10 11 while True: 12 for batch_idx, (state, move, winner) in enumerate(dataloader): 13 14 ## Evaluate a copy of the current network 15 if total_ite % TRAIN_STEPS == 0: 16 pending_player = deepcopy(player) 17 result = evaluate(pending_player, best_player) 18 19 if result: 20 best_player = pending_player 21 22 example = { 23 'state': state, 24 'winner': winner, 25 'move' : move 26 } 27 optimizer.zero_grad() 28 winner, probas = pending_player.predict(example['state']) 29 30 loss = criterion(winner, example['winner'], \ 31 probas, example['move']) 32 loss.backward() 33 optimizer.step() 34 35 ## Fetch new games 36 if total_ite % REFRESH_TICK == 0: 37 last_id = fetch_new_games(collection, dataset, last_id) 訓練用的損失函數(shù)表示如下: 1class AlphaLoss(torch.nn.Module): 2 def __init__(self): 3 super(AlphaLoss, self).__init__() 4 5 def forward(self, pred_winner, winner, pred_probas, probas): 6 value_error = (winner - pred_winner) ** 2 7 policy_error = torch.sum((-probas * 8 (1e-6 + pred_probas).log()), 1) 9 total_error = (value_error.view(-1) + policy_error).mean() 10 return total_error 三是評估 (Evaluation) ,看訓練過的智能體,比起正在生成數(shù)據(jù)的智能體,是不是更優(yōu)秀了 (最優(yōu)秀者回到第一步,繼續(xù)生成數(shù)據(jù)) 。 1def evaluate(player, new_player): 2 results = play(player, opponent=new_player) 3 black_wins = 0 4 white_wins = 0 5 6 for result in results: 7 if result[0] == 1: 8 white_wins += 1 9 elif result[0] == 0: 10 black_wins += 1 11 12 ## Check if the trained player (black) is better than 13 ## the current best player depending on the threshold 14 if black_wins >= EVAL_THRESH * len(results): 15 return True 16 return False 第三部分很重要,要不斷選出最優(yōu)的網(wǎng)絡,來不斷生成高質量的數(shù)據(jù),才能提升AI的棋藝。 三個環(huán)節(jié)周而復始,才能養(yǎng)成強大的棋手。 年幼的SuperGo小笛用學校的服務器訓練了AI棋手一星期。 SuperGo還年幼,是在9x9棋盤上訓練的。 小笛說,他的AI現(xiàn)在好像還不懂生死一類的事,但應該已經(jīng)知道圍棋是個搶地盤的游戲了。 雖然,沒有訓練出什么超神的棋手,但這次嘗試依然值得慶祝。 Reddit上面也有同仁發(fā)來賀電。 △ 有前途的意思 有志于AI圍棋的各位,也可以試一試這個PyTorch實現(xiàn)。 代碼實現(xiàn)傳送門: https://github.com/dylandjian/SuperGo 教程原文傳送門: https://dylandjian./alphago-zero/ AlphaGo Zero論文傳送門: https://www./articles/nature24270.epdf 最后一句昨天 (8月2日) ,是柯潔的生日。 — 完 — 誠摯招聘 |
|
|