Training Process in model.FeedForward

Softmax Executor

Load Inputs

After the softmax executor is initialized, it loads data. The inputs are the embeddings (batch_size x embedding_size) and the ground truth labels of the current batch.

softmax_executor.load_data(cpu_out_feature, total_label)

After load_data, the softmax executor's part_in_feature_n and in_feature attributes are updated.

Forward

The forward pass of the softmax executor is then wrapped in its forward method:

softmax_executor.forward(max_exp, sum_exp)

To understand what happens in forward, we need to dig into the machine_parallel_loss code.

With a margin-based softmax (e.g. ArcFace), the fully connected layer weights need to be L2-normalized, which is done by weightl2_forward():

self.weightl2_forward()

This updates the softmax executor's part_weight_n and weight_param, which now hold the L2-normalized weights.
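As a rough illustration of what this normalization does, the NumPy sketch below rescales each class weight vector to unit L2 norm. It is not the executor's implementation: the helper name l2_normalize_rows, the (num_classes x embedding_size) layout, and the eps guard are all assumptions made for the example.

```python
import numpy as np

def l2_normalize_rows(weight, eps=1e-12):
    """Scale every row of `weight` to unit L2 norm (hypothetical helper)."""
    # Assumed layout: one row per class, one column per embedding dimension.
    norms = np.linalg.norm(weight, axis=1, keepdims=True)
    # eps avoids division by zero for an all-zero row (an added safeguard).
    return weight / np.maximum(norms, eps)

weight = np.array([[3.0, 4.0], [0.5, 0.0], [1.0, 1.0]])
weight_n = l2_normalize_rows(weight)
row_norms = np.linalg.norm(weight_n, axis=1)  # every entry is 1 up to rounding
```

When both the weight rows and the embeddings have unit norm, each dot product equals a cosine similarity.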

The embeddings (cpu_out_feature) then pass through the fully connected layer; the outputs are written to part_fully_connected_out_n and fully_connected_out.

self.fullyconnect_forward(self.local_name['in_feature'])

For a margin-based softmax, margin_forward is executed next:

self.margin_forward('fully_connected_out', margin_s, margin_m)

gt_index is a list whose length equals the total number of contexts. Its i-th item is a length-2 list: the first element (an ndarray of shape 1 x n) holds the indices of the samples whose ground truth labels live on context i, and the second element (an ndarray of shape n) holds the ground truth labels of those samples. Within a mini-batch, the logits of these samples at their ground truth labels can therefore be indexed by [gt_index[i][0], gt_index[i][1]].

It follows that margin_tmp holds the logits corresponding to the ground truth labels, and self.margin_tmp_space is a list of these logits across all contexts. Dividing the logits by margin_s gives the cosine similarity between each embedding and its corresponding class center (the matching FC weight vector).
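The indexing described above can be tried out on a toy example. The NumPy sketch below builds the logits of one context by hand and picks out the ground truth entries. All values are made up, and it assumes that the labels in gt_index[i][1] are column indices into that context's logits.

```python
import numpy as np

margin_s = 64.0

# Toy cosine similarities on one context: 4 samples x 3 classes held there.
cos_sim = np.array([
    [0.10, 0.80, -0.20],
    [0.30, 0.00, 0.50],
    [0.90, -0.40, 0.20],
    [-0.10, 0.60, 0.70],
])
logits = margin_s * cos_sim

# Samples 0 and 2 have their ground truth class on this context
# (assumed to be columns 1 and 0 of the local logits).
gt_index_i = [np.array([[0, 2]]), np.array([1, 0])]

# Index arrays of shape (1, n) and (n,) broadcast to a (1, n) result.
margin_tmp = logits[gt_index_i[0], gt_index_i[1]]

# Dividing by margin_s recovers the ground truth cosine similarities.
cos_gt = margin_tmp / margin_s
```

Here margin_tmp has shape (1, 2), and cos_gt holds cos_sim[0, 1] and cos_sim[2, 0].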

The output of the fully connected layer is then fed to the softmax forward pass:

self.softmax_forward('fully_connected_out')
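One common way to compute a softmax when the classes are split across contexts is to share the row-wise maximum and the row-wise sum of exponentials between the partitions. The NumPy sketch below illustrates that idea on a single machine. It assumes, without confirmation from the source, that max_exp and sum_exp play these two roles; partitioned_softmax and its argument are hypothetical names.

```python
import numpy as np

def partitioned_softmax(parts):
    """Softmax over classes split into blocks (hypothetical helper).

    parts: list of (batch, classes_i) logit arrays, one per context.
    Returns a list of probability blocks with the same shapes.
    """
    # Row-wise maximum over all partitions, used for numerical stability.
    global_max = np.max([p.max(axis=1) for p in parts], axis=0)[:, None]
    exps = [np.exp(p - global_max) for p in parts]
    # Row-wise sum of exponentials over all partitions.
    global_sum = np.sum([e.sum(axis=1) for e in exps], axis=0)[:, None]
    return [e / global_sum for e in exps]

rng = np.random.default_rng(0)
full_logits = rng.normal(size=(4, 7))
parts = [full_logits[:, :3], full_logits[:, 3:]]  # two simulated contexts
probs = np.concatenate(partitioned_softmax(parts), axis=1)
```

Each partition only needs two per-sample scalars from the others, so the full logit matrix never has to be gathered in one place.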

Question: what does mpi_gather_bcast do?

Backward

The gradient of the loss with respect to the softmax input (the logits) equals the softmax output with 1 subtracted at each sample's ground truth index. The loss itself is the negative natural logarithm of the softmax output (probability) at the ground truth class.
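This gradient rule can be checked numerically. The NumPy sketch below assumes a summed (not averaged) cross-entropy loss and differentiates with respect to the logits fed into softmax; the function names are illustrative and do not come from the executor.

```python
import numpy as np

def softmax(logits):
    # Subtract the row-wise max for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)

def cross_entropy(logits, labels):
    # Negative log-probability of the ground truth class, summed over the batch.
    p = softmax(logits)
    return -np.log(p[np.arange(len(labels)), labels]).sum()

def cross_entropy_grad(logits, labels):
    # Softmax output with 1 subtracted at the ground truth index.
    grad = softmax(logits)
    grad[np.arange(len(labels)), labels] -= 1.0
    return grad

logits = np.array([[1.0, 2.0, 0.5], [0.1, -0.3, 0.7]])
labels = np.array([1, 2])
analytic = cross_entropy_grad(logits, labels)

# Central finite differences as an independent check.
eps = 1e-6
numeric = np.zeros_like(logits)
for i in range(logits.shape[0]):
    for j in range(logits.shape[1]):
        delta = np.zeros_like(logits)
        delta[i, j] = eps
        numeric[i, j] = (cross_entropy(logits + delta, labels)
                         - cross_entropy(logits - delta, labels)) / (2 * eps)
```

The analytic and numeric gradients agree up to finite-difference error, and each row of the gradient sums to zero.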

margin_func

The way margin_m is applied to the logits deserves its own section. Details can be found in the _margin_func function.

The goal is simple: given the input margin_s * cos(theta), output margin_s * cos(theta + margin_m).

cos_t = in_data / margin_s = cos(theta0)
cos_m = cos(margin_m) = cos(m) # m is shorthand for margin_m
sin_m = sin(margin_m) = sin(m)
mm = margin_m * sin(pi - m) = m * sin(m)

Since easy_margin is true by default, the computation continues as follows:

cond = relu(cos_t) = relu(cos(theta0)) # 0 if theta0 is pi/2 or larger
body = cos_t * cos_t = (cos(theta0))^2
body = 1 - body = (sin(theta0))^2
sin_t = sqrt(body) = sin(theta0)

out_data = cos_t * cos_m = cos(theta0) * cos(m)
b = sin_t * sin_m = sin(theta0) * sin(m)
out_data = out_data - b = cos(theta0 + m) # there you go
out_data = out_data * margin_s = margin_s * cos(theta0 + m)
zy_keep = in_data = margin_s * cos(theta0)

out_data = where(cond, out_data, zy_keep)

The last operation reads as follows: if the input is positive, i.e. theta0 lies in (0, pi/2), the margin m is applied; otherwise, i.e. theta0 is pi/2 or larger, the input is returned unchanged.
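The easy-margin steps above can be collected into a short NumPy sketch. It mirrors the walkthrough and is not the actual _margin_func: the function name easy_margin_logits is hypothetical, and the clip before the square root is an added guard against rounding error.

```python
import math
import numpy as np

def easy_margin_logits(in_data, margin_s, margin_m):
    """Map margin_s * cos(theta0) to margin_s * cos(theta0 + margin_m)
    wherever cos(theta0) > 0; leave other entries unchanged."""
    cos_t = in_data / margin_s
    cos_m = math.cos(margin_m)
    sin_m = math.sin(margin_m)
    cond = np.maximum(cos_t, 0.0)  # relu: zero when theta0 >= pi/2
    # sin(theta0) from cos(theta0); the clip keeps the argument non-negative.
    sin_t = np.sqrt(np.clip(1.0 - cos_t * cos_t, 0.0, None))
    # cos(theta0 + m) = cos(theta0) * cos(m) - sin(theta0) * sin(m)
    out_data = margin_s * (cos_t * cos_m - sin_t * sin_m)
    return np.where(cond > 0, out_data, in_data)

margin_s, margin_m = 64.0, 0.5
# One angle below pi/2 (0.5 rad) and one above it (2.0 rad).
in_data = np.array([margin_s * math.cos(0.5), margin_s * math.cos(2.0)])
out_data = easy_margin_logits(in_data, margin_s, margin_m)
```

The first entry becomes margin_s * cos(0.5 + 0.5), while the second, whose angle exceeds pi/2, passes through unchanged.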
