GEMV 8: Routes and Fabric DSDs, Part III

GEMV 8: Routes and Fabric DSDs, Part III

Continuing from the previous example, we now extend the GEMV computation to a kernel_x_dim x kernel_y_dim rectangle of PEs. We make one simplification, enforcing that M is a multiple of kernel_y_dim, and N is a multiple of kernel_x_dim.

The host program copies b into the y tensor of the left column of PEs, with each PE getting the corresponding M/kernel_y_dim values. Each PE also gets a corresponding chunk of A, consisting of M/kernel_y_dim x N/kernel_x_dim elements. Similarly, the host program copies x into the upper row of PEs, with each PE getting N/kernel_x_dim values.

When the compute function is launched, the PEs in the top row begin sending their respective elements of x to their routers, along the color x_color. These PEs send the elements of x both to the SOUTH and back down their own RAMP. All other rows receive elements of x along x_color from the NORTH and transmit them down their RAMP, with all but the last row also transmitting the elements further SOUTH.

On all PEs, receiving a wavelet along x_color activates recv_x. This task is a wavelet-triggered task (WTT): the wavelet’s data is fed in as an argument to recv_x.

When a PE receives an element of x in the recv_x task, it increments its local y tensor by computing the corresponding piece of Ax. When a PE has received all corresponding elements of x along x_color, with each PE receiving N/kernel_x_dim values, it has finished computing its local contribution to y.

At this point, the local task reduce is activated. We use two colors in a checkerboard pattern to accumulate the partial y results from WEST to EAST. On the even columns, we use ax_color_1 to receive partial results from the WEST and ax_color_2 to send partial results to the EAST; on the odd columns, we use ax_color_2 to receive partial results from the WEST and ax_color_1 to send partial results to the EAST. We must use this checkerboard pattern because we cannot safely send and receive multiple wavelets on the same color with a fixed routing. In a future example, we will demonstrate the use of dynamic switching to update the color routing, which will allow you to use a single color.

The leftmost column of PEs has nothing to receive from the WEST, so these PEs only send their partial y results to the EAST. The remaining columns, upon receiving a partial y result from the WEST, increment their y tensors by the received values, and all but the rightmost column sends those values to the EAST The values in the rightmost column contain the final result y, with each PE containing M/kernel_y_dim elements.

Last, the host copies y from the right column of PEs, and checks that the result is correct.

layout.csl

// kernel dimensions
param kernel_x_dim: i16;
param kernel_y_dim: i16;

// total matrix dimensions
param M: i16;
param N: i16;

// Colors
const ax_color_1: color = @get_color(0); // sends/recvs partial result Ax EAST/WEST
const ax_color_2: color = @get_color(1); // sends/recvs partial result Ax EAST/WEST
const x_color:    color = @get_color(2); // sends/recvs elems x
const LAUNCH:     color = @get_color(8); // Color used by memcpy for kernel launch

// Task IDs
const recv_x_task_id: data_task_id  = @get_data_task_id(x_color); // Task ID for WTT
const exit_task_id:   local_task_id = @get_local_task_id(9);      // Task ID used by local task
const reduce_task_id: local_task_id = @get_local_task_id(10);     // Task ID used by local task

// This example uses kernel_x_dim x kernel_y_dim PEs
const memcpy = @import_module("<memcpy/get_params>", .{
  .width = kernel_x_dim,
  .height = kernel_y_dim,
  .LAUNCH = LAUNCH
});

layout {
  // PE coordinates are (column, row)
  @set_rectangle(kernel_x_dim, kernel_y_dim);

  // PE rectangle dimensions should evenly divide matrix dimensions
  @comptime_assert(M % kernel_y_dim == 0);
  @comptime_assert(N % kernel_x_dim == 0);

  const common_params = .{
    .M_per_PE = M / kernel_y_dim,
    .N_per_PE = N / kernel_x_dim,
    .kernel_x_dim = kernel_x_dim,
    .kernel_y_dim = kernel_y_dim,
    .x_color = x_color,
    .recv_x_task_id = recv_x_task_id,
    .exit_task_id = exit_task_id,
    .reduce_task_id = reduce_task_id
  };

  const even_params = @concat_structs(common_params, .{
    .send_east_color = ax_color_1, .recv_west_color = ax_color_2
  });

  const odd_params = @concat_structs(common_params, .{
    .send_east_color = ax_color_2, .recv_west_color = ax_color_1
  });

  for (@range(i16, kernel_x_dim)) |pe_x| {
    for (@range(i16, kernel_y_dim)) |pe_y| {
      if (pe_x % 2 == 0) {
        @set_tile_code(pe_x, pe_y, "pe_program.csl", @concat_structs(
          .{ .memcpy_params = memcpy.get_params(pe_x) }, even_params));
      } else {
        @set_tile_code(pe_x, pe_y, "pe_program.csl", @concat_structs(
          .{ .memcpy_params = memcpy.get_params(pe_x) }, odd_params));
      }
    }
  }

  // Create route values
  const RX_R_TX_RS = .{ .rx = .{RAMP},  .tx = .{RAMP, SOUTH} };
  const RX_N_TX_RS = .{ .rx = .{NORTH}, .tx = .{RAMP, SOUTH} };
  const RX_N_TX_R  = .{ .rx = .{NORTH}, .tx = .{RAMP} };

  const RX_W_TX_R  = .{ .rx = .{WEST},  .tx = .{RAMP} };
  const RX_R_TX_E  = .{ .rx = .{RAMP},  .tx = .{EAST} };

  for (@range(i16, kernel_x_dim)) |pe_x| {
    for (@range(i16, kernel_y_dim)) |pe_y| {
      if (pe_y == 0) {
        @set_color_config(pe_x, pe_y, x_color, .{ .routes = RX_R_TX_RS });
      } else if (pe_y == kernel_y_dim-1) {
        @set_color_config(pe_x, pe_y, x_color, .{ .routes = RX_N_TX_R  });
      } else {
        @set_color_config(pe_x, pe_y, x_color, .{ .routes = RX_N_TX_RS });
      }

      if (pe_x % 2 == 0) {
        @set_color_config(pe_x, pe_y, ax_color_1, .{ .routes = RX_R_TX_E });
        @set_color_config(pe_x, pe_y, ax_color_2, .{ .routes = RX_W_TX_R });
      } else {
        @set_color_config(pe_x, pe_y, ax_color_1, .{ .routes = RX_W_TX_R });
        @set_color_config(pe_x, pe_y, ax_color_2, .{ .routes = RX_R_TX_E });
      }
    }
  }

  // export symbol names
  @export_name("A", [*]f32, true);
  @export_name("x", [*]f32, true);
  @export_name("y", [*]f32, true);
  @export_name("compute", fn()void);
}

pe_program.csl

param memcpy_params: comptime_struct;

// Matrix dimensions
param M_per_PE: i16;
param N_per_PE: i16;

// Program rectangle dimensions
param kernel_x_dim: u16;
param kernel_y_dim: u16;

// Colors
param send_east_color: color; // sends partial result Ax EAST
param recv_west_color: color; // recvs partial result Ax WEST
param x_color:         color; // sends elems x SOUTH

// Task IDs
param recv_x_task_id: data_task_id;  // Task ID for data task recv_x, consumes x_color wlts
param exit_task_id:   local_task_id; // Task ID used by exit task to unblock cmd stream
param reduce_task_id: local_task_id; // Task ID used by reduce task

// memcpy module provides infrastructure for copying data
// and launching functions from the host
const sys_mod = @import_module("<memcpy/memcpy>", memcpy_params);

// layout module provides PE coordinates at runtime
const layout_mod = @import_module("<layout>");


// 48 kB of global memory contain A, x, b, y
var A: [M_per_PE*N_per_PE]f32; // A is stored column major
var x: [N_per_PE]f32;
var y: [M_per_PE]f32;

// DSDs for accessing A, x, y
// A_dsd accesses column of A
var A_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{M_per_PE} -> A[i] });
var x_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{N_per_PE} -> x[i] });
var y_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{M_per_PE} -> y[i] });

// ptrs to A, x, b, y will be advertised as symbols to host
var A_ptr: [*]f32 = &A;
var x_ptr: [*]f32 = &x;
var y_ptr: [*]f32 = &y;


fn is_top_row() bool {
  return (layout_mod.get_y_coord() == 0);
}

fn is_left_col() bool {
  return (layout_mod.get_x_coord() == 0);
}

fn is_right_col() bool {
  return (layout_mod.get_x_coord() == kernel_x_dim-1);
}


task reduce() void {

  const in_dsd  = @get_dsd(fabin_dsd, .{
                    .fabric_color = recv_west_color, .extent = M_per_PE,
                    .input_queue = @get_input_queue(1)
                  });

  const out_dsd = @get_dsd(fabout_dsd, .{
                    .fabric_color = send_east_color, .extent = M_per_PE,
                    .output_queue = @get_output_queue(2)
                  });

  // After fmovs is done, activate exit_task to unblock cmd_stream
  if (is_left_col()) {
    @fmovs(out_dsd, y_dsd, .{ .async = true, .activate = exit_task_id });
  } else if (is_right_col()) {
    @fadds(y_dsd, y_dsd, in_dsd, .{ .async = true, .activate = exit_task_id });
  } else {
    @fadds(out_dsd, y_dsd, in_dsd, .{ .async = true, .activate = exit_task_id });
  }
}


// Use to keep track of # of invocations of recv_x task
// when num_recv_x == N_per_PE, we are done receiving x elements
var num_recv_x: i16 = 0;

task recv_x(x_val: f32) void {
  @fmacs(y_dsd, y_dsd, A_dsd, x_val);
  A_dsd = @increment_dsd_offset(A_dsd, M_per_PE, f32);

  num_recv_x += 1;
  if (num_recv_x == N_per_PE) {
    @activate(reduce_task_id);
  }
}


// The top row sends x values along x_color to launch recv_x
fn compute() void {
  if (is_top_row()) {
    const send_x_dsd = @get_dsd(fabout_dsd, .{
                         .fabric_color = x_color, .extent = N_per_PE,
                         .output_queue = @get_output_queue(3)
                       });
    @fmovs(send_x_dsd, x_dsd, .{ .async = true });
  }
}


task exit_task() void {
  sys_mod.unblock_cmd_stream();
}


comptime {
  // When exit_task_id is activated, exit_task will execute
  @bind_local_task(exit_task, exit_task_id);

  // recv_x is wavelet-triggered task (WTT)
  // activated by receiving wavelets along color x_color,
  // which corresponds to recv_x_task_id
  @bind_data_task(recv_x, recv_x_task_id);

  // reduce is local task activated by ID reduce_task_ID
  @bind_local_task(reduce, reduce_task_id);

  @export_symbol(A_ptr, "A");
  @export_symbol(x_ptr, "x");
  @export_symbol(y_ptr, "y");
  @export_symbol(compute);
  @rpc(@get_data_task_id(sys_mod.LAUNCH));
}

run.py

#!/usr/bin/env cs_python

import argparse
import json
import numpy as np

from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType, MemcpyOrder # pylint: disable=no-name-in-module

# Read arguments
parser = argparse.ArgumentParser()
parser.add_argument('--name', help="the test compile output dir")
parser.add_argument('--cmaddr', help="IP:port for CS system")
args = parser.parse_args()

# Get matrix dimensions from compile metadata
with open(f"{args.name}/out.json", encoding='utf-8') as json_file:
  compile_data = json.load(json_file)

# Matrix dimensions
N = int(compile_data['params']['N']) # columns
M = int(compile_data['params']['M']) # rows

# PE grid dimensions
kernel_x_dim = int(compile_data['params']['kernel_x_dim'])
kernel_y_dim = int(compile_data['params']['kernel_y_dim'])

# Construct A, x, b
A = np.arange(M*N, dtype=np.float32).reshape(M,N)
x = np.full(shape=N, fill_value=1.0, dtype=np.float32)
b = np.full(shape=M, fill_value=2.0, dtype=np.float32)

# Calculate expected y
y_expected = A@x + b

# Size of N dimension on each PE
N_per_PE = N // kernel_x_dim
M_per_PE = M // kernel_y_dim

# Construct a runner using SdkRuntime
runner = SdkRuntime(args.name, cmaddr=args.cmaddr)

# Get symbols for A, x, y on device
A_symbol = runner.get_id('A')
x_symbol = runner.get_id('x')
y_symbol = runner.get_id('y')

# Load and run the program
runner.load()
runner.run()

# Copy b into y of PEs (0, 0) to (0, kernel_y_dim-1)
runner.memcpy_h2d(y_symbol, b, 0, 0, 1, kernel_y_dim, M_per_PE, streaming=False,
  order=MemcpyOrder.ROW_MAJOR, data_type=MemcpyDataType.MEMCPY_32BIT, nonblock=False)

# Copy chunks of A into all PEs
# Each chunk on each PE is stored column major
A_prepared = A.reshape(kernel_y_dim, M_per_PE, kernel_x_dim, N_per_PE).transpose(0, 2, 3, 1).ravel()
runner.memcpy_h2d(A_symbol, A_prepared, 0, 0, kernel_x_dim, kernel_y_dim, M_per_PE*N_per_PE,
  streaming=False, order=MemcpyOrder.ROW_MAJOR, data_type=MemcpyDataType.MEMCPY_32BIT,
  nonblock=False)

# Copy x into PEs (0, 0) and (kernel_x_dim-1, 0)
# PE (0, 0) gets first N/2 elements; PE (1, 0) gets last N/2 elements
runner.memcpy_h2d(x_symbol, x, 0, 0, kernel_x_dim, 1, N_per_PE, streaming=False,
  order=MemcpyOrder.ROW_MAJOR, data_type=MemcpyDataType.MEMCPY_32BIT, nonblock=False)

# Launch the compute function on device
runner.launch('compute', nonblock=False)

# Copy y back from PEs (kernel_x_dim-1, 0) and (kernel_x_dim-1, kernel_y_dim-1)
y_result = np.zeros([M], dtype=np.float32)
runner.memcpy_d2h(y_result, y_symbol, kernel_x_dim-1, 0, 1, kernel_y_dim, M_per_PE, streaming=False,
  order=MemcpyOrder.ROW_MAJOR, data_type=MemcpyDataType.MEMCPY_32BIT, nonblock=False)

# Stop the program
runner.stop()

# Ensure that the result matches our expectation
np.testing.assert_allclose(y_result, y_expected)
print("SUCCESS!")

commands.sh

#!/usr/bin/env bash

set -e

cslc ./layout.csl --fabric-dims=11,5 \
--fabric-offsets=4,1 --params=kernel_x_dim:4,kernel_y_dim:3,M:6,N:8 \
-o out --memcpy --channels 1
cs_python run.py --name out