Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Testing out coatnet style relative attention #182

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from

Conversation

Tilps
Copy link
Contributor

@Tilps Tilps commented Oct 31, 2021

multi_head_relative_attention is a bit overkill, I forked the entire multi_head_attention from keras, it supports all kinds of stuff, in terms of dimensions, but with the relative logic added, it is now limited to NHWC input despite all the options indicating otherwise.

@Tilps
Copy link
Contributor Author

Tilps commented Oct 31, 2021

net.proto to go with my saving tweaks.
/*
This file is part of Leela Chess Zero.
Copyright (C) 2018 The LCZero Authors

Leela Chess is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Leela Chess is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with Leela Chess. If not, see http://www.gnu.org/licenses/.

Additional permission under GNU GPL version 3 section 7

If you modify this Program, or any covered work, by linking or
combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
modified version of those libraries), containing parts covered by the
terms of the respective license agreement, the licensors of this
Program grant you additional permission to convey the resulting work.
*/
syntax = "proto2";

package pblczero;

message EngineVersion {
optional uint32 major = 1;
optional uint32 minor = 2;
optional uint32 patch = 3;
}

message Weights {
message Layer {
optional float min_val = 1;
optional float max_val = 2;
optional bytes params = 3;
}

message ConvBlock {
optional Layer weights = 1;
optional Layer biases = 2;
optional Layer bn_means = 3;
optional Layer bn_stddivs = 4;
optional Layer bn_gammas = 5;
optional Layer bn_betas = 6;
}

message SEunit {
// Squeeze-excitation unit (https://arxiv.org/abs/1709.01507)
// weights and biases of the two fully connected layers.
optional Layer w1 = 1;
optional Layer b1 = 2;
optional Layer w2 = 3;
optional Layer b2 = 4;
}

message Residual {
optional ConvBlock conv1 = 1;
optional ConvBlock conv2 = 2;
optional SEunit se = 3;
}

message MHRA {
optional Layer lngammas = 1;
optional Layer lnbetas = 2;
optional Layer rel_bias_table = 3;
optional Layer qw = 4;
optional Layer qb = 5;
optional Layer kw = 6;
optional Layer kb = 7;
optional Layer vw = 8;
optional Layer vb = 9;
optional Layer ow = 10;
optional Layer ob = 11;
}

message FFN {
optional Layer lngammas = 1;
optional Layer lnbetas = 2;
optional Layer fc1w = 3;
optional Layer fc1b = 4;
optional Layer fc2w = 5;
optional Layer fc2b = 6;
}

message RRA {
optional MHRA mhra = 1;
optional FFN ffn = 2;
}

// Input convnet.
optional ConvBlock input = 1;

// Residual tower.
repeated Residual residual = 2;

// Policy head
// Extra convolution for AZ-style policy head
optional ConvBlock policy1 = 11;
optional ConvBlock policy = 3;
optional Layer ip_pol_w = 4;
optional Layer ip_pol_b = 5;

// Value head
optional ConvBlock value = 6;
optional Layer ip1_val_w = 7;
optional Layer ip1_val_b = 8;
optional Layer ip2_val_w = 9;
optional Layer ip2_val_b = 10;

// Moves left head
optional ConvBlock moves_left = 12;
optional Layer ip1_mov_w = 13;
optional Layer ip1_mov_b = 14;
optional Layer ip2_mov_w = 15;
optional Layer ip2_mov_b = 16;

// rra tower.
repeated RRA rra = 17;
}

message TrainingParams {
optional uint32 training_steps = 1;
optional float learning_rate = 2;
optional float mse_loss = 3;
optional float policy_loss = 4;
optional float accuracy = 5;
optional string lc0_params = 6;
}

message NetworkFormat {
// Format to encode the input planes with. Used by position encoder.
enum InputFormat {
INPUT_UNKNOWN = 0;
INPUT_CLASSICAL_112_PLANE = 1;
INPUT_112_WITH_CASTLING_PLANE = 2;
INPUT_112_WITH_CANONICALIZATION = 3;
INPUT_112_WITH_CANONICALIZATION_HECTOPLIES = 4;
INPUT_112_WITH_CANONICALIZATION_HECTOPLIES_ARMAGEDDON = 132;
INPUT_112_WITH_CANONICALIZATION_V2 = 5;
INPUT_112_WITH_CANONICALIZATION_V2_ARMAGEDDON = 133;
}
optional InputFormat input = 1;

// Output format of the NN. Used by search code to interpret results.
enum OutputFormat {
OUTPUT_UNKNOWN = 0;
OUTPUT_CLASSICAL = 1;
OUTPUT_WDL = 2;
}
optional OutputFormat output = 2;

// Network architecture. Used by backends to build the network.
enum NetworkStructure {
// Networks without PolicyFormat or ValueFormat specified
NETWORK_UNKNOWN = 0;
NETWORK_CLASSICAL = 1;
NETWORK_SE = 2;
// Networks with PolicyFormat and ValueFormat specified
NETWORK_CLASSICAL_WITH_HEADFORMAT = 3;
NETWORK_SE_WITH_HEADFORMAT = 4;
}
optional NetworkStructure network = 3;

// Policy head architecture
enum PolicyFormat {
POLICY_UNKNOWN = 0;
POLICY_CLASSICAL = 1;
POLICY_CONVOLUTION = 2;
}
optional PolicyFormat policy = 4;

// Value head architecture
enum ValueFormat {
VALUE_UNKNOWN = 0;
VALUE_CLASSICAL = 1;
VALUE_WDL = 2;
VALUE_PARAM = 3;
}
optional ValueFormat value = 5;

// Moves left head architecture
enum MovesLeftFormat {
MOVES_LEFT_NONE = 0;
MOVES_LEFT_V1 = 1;
}
optional MovesLeftFormat moves_left = 6;
}

message Format {
enum Encoding {
UNKNOWN = 0;
LINEAR16 = 1;
}

optional Encoding weights_encoding = 1;
// If network_format is missing, it's assumed to have
// INPUT_CLASSICAL_112_PLANE / OUTPUT_CLASSICAL / NETWORK_CLASSICAL format.
optional NetworkFormat network_format = 2;
}

message Net {
optional fixed32 magic = 1;
optional string license = 2;
optional EngineVersion min_version = 3;
optional Format format = 4;
optional TrainingParams training_params = 5;
optional Weights weights = 10;
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant