const std = @import("std");
const builtin = @import("builtin");
/// Suggests a vector length (number of lanes) for elements of type `T` on `cpu`.
///
/// The register width is a heuristic picked from the CPU architecture and its
/// feature set. The element width is `@bitSizeOf(T)` rounded up to a power of two,
/// with a minimum of 8 bits, so `bool` and `u1` count as byte-sized lanes.
///
/// Returns `null` when:
/// - no SIMD feature is recognised for the target, or
/// - one element is at least as wide as the suggested register.
///
/// On x86, a `bool` query on a CPU with `prefer_mask_registers` returns 64 directly.
pub fn suggestVectorSizeForCpu(comptime T: type, comptime cpu: std.Target.Cpu) ?usize {
    const element_bit_size = @max(8, std.math.ceilPowerOfTwo(u16, @bitSizeOf(T)) catch unreachable);
    const vector_bit_size: u16 = blk: {
        if (cpu.arch.isX86()) {
            // The feature set must be passed explicitly. Without it, every `bool`
            // query on x86 failed to compile.
            if (T == bool and std.Target.x86.featureSetHas(cpu.features, .prefer_mask_registers)) return 64;
            if (std.Target.x86.featureSetHas(cpu.features, .avx512f) and !std.Target.x86.featureSetHasAny(cpu.features, .{ .prefer_256_bit, .prefer_128_bit })) break :blk 512;
            if (std.Target.x86.featureSetHasAny(cpu.features, .{ .prefer_256_bit, .avx2 }) and !std.Target.x86.featureSetHas(cpu.features, .prefer_128_bit)) break :blk 256;
            if (std.Target.x86.featureSetHas(cpu.features, .sse)) break :blk 128;
            if (std.Target.x86.featureSetHasAny(cpu.features, .{ .mmx, .@"3dnow" })) break :blk 64;
        } else if (cpu.arch.isARM()) {
            if (std.Target.arm.featureSetHas(cpu.features, .neon)) break :blk 128;
        } else if (cpu.arch.isAARCH64()) {
            // SVE is scalable; 128 bits is used here as a conservative width.
            if (std.Target.aarch64.featureSetHas(cpu.features, .sve)) break :blk 128;
            if (std.Target.aarch64.featureSetHas(cpu.features, .neon)) break :blk 128;
        } else if (cpu.arch.isPPC() or cpu.arch.isPPC64()) {
            if (std.Target.powerpc.featureSetHas(cpu.features, .altivec)) break :blk 128;
        } else if (cpu.arch.isMIPS()) {
            if (std.Target.mips.featureSetHas(cpu.features, .msa)) break :blk 128;
            if (std.Target.mips.featureSetHas(cpu.features, std.Target.mips.Feature.mips3d)) break :blk 64;
        } else if (cpu.arch.isRISCV()) {
            if (std.Target.riscv.featureSetHas(cpu.features, .v)) break :blk 128;
        } else if (cpu.arch.isSPARC()) {
            if (std.Target.sparc.featureSetHasAny(cpu.features, .{ .vis, .vis2, .vis3 })) break :blk 64;
        }
        // No known SIMD support for this target.
        return null;
    };
    // A "vector" of a single (or partial) element is not worth suggesting.
    if (vector_bit_size <= element_bit_size) return null;
    return @divExact(vector_bit_size, element_bit_size);
}
/// Like `suggestVectorSizeForCpu`, but for the CPU being compiled for (`builtin.cpu`).
pub fn suggestVectorSize(comptime T: type) ?usize {
    const target_cpu = builtin.cpu;
    return suggestVectorSizeForCpu(T, target_cpu);
}
// Signedness must not affect the suggestion. With avx512f added to the x86_64
// baseline, 512-bit registers hold 16 lanes of 32 bits.
test "suggestVectorSizeForCpu works with signed and unsigned values" {
    comptime var cpu = std.Target.Cpu.baseline(std.Target.Cpu.Arch.x86_64);
    comptime cpu.features.addFeature(@enumToInt(std.Target.x86.Feature.avx512f));
    const signed_integer_size = suggestVectorSizeForCpu(i32, cpu).?;
    const unsigned_integer_size = suggestVectorSizeForCpu(u32, cpu).?;
    try std.testing.expectEqual(@as(usize, 16), unsigned_integer_size);
    try std.testing.expectEqual(@as(usize, 16), signed_integer_size);
}
/// Returns the number of elements of a vector or array type.
/// Any other type is a compile error.
fn vectorLength(comptime VectorType: type) comptime_int {
    const info = @typeInfo(VectorType);
    return if (info == .Vector)
        info.Vector.len
    else if (info == .Array)
        info.Array.len
    else
        @compileError("Invalid type " ++ @typeName(VectorType));
}
/// Returns the smallest unsigned integer type that can index every element of
/// `VectorType`, i.e. that holds 0 through length - 1.
pub fn VectorIndex(comptime VectorType: type) type {
    const last_index = vectorLength(VectorType) - 1;
    return std.math.IntFittingRange(0, last_index);
}
/// Returns the smallest unsigned integer type that can hold any element count of
/// `VectorType`, i.e. 0 through the vector length inclusive.
pub fn VectorCount(comptime VectorType: type) type {
    const max_count = vectorLength(VectorType);
    return std.math.IntFittingRange(0, max_count);
}
/// Returns the vector `{ 0, 1, 2, ..., len - 1 }` with elements of type `T`.
/// `T` must be an integer or float type; anything else is a compile error.
pub fn iota(comptime T: type, comptime len: usize) @Vector(len, T) {
    var values: [len]T = undefined;
    var idx: usize = 0;
    while (idx < len) : (idx += 1) {
        values[idx] = switch (@typeInfo(T)) {
            .Int => @intCast(T, idx),
            .Float => @intToFloat(T, idx),
            else => @compileError("Can't use type " ++ @typeName(T) ++ " in iota."),
        };
    }
    return @as(@Vector(len, T), values);
}
/// Returns a vector of length `len` built by cycling through the elements of `vec`.
/// For example, `repeat(6, @Vector(4, u32){ 10, 20, 30, 40 })` yields
/// `{ 10, 20, 30, 40, 10, 20 }`.
pub fn repeat(comptime len: usize, vec: anytype) @Vector(len, std.meta.Child(@TypeOf(vec))) {
    const Source = @TypeOf(vec);
    const source_len = @intCast(i32, vectorLength(Source));
    // Output lane k reads source lane (k mod source_len).
    const mask = comptime (iota(i32, len) % @splat(len, source_len));
    return @shuffle(std.meta.Child(Source), vec, undefined, mask);
}
/// Concatenates two vectors: the elements of `a`, followed by the elements of `b`.
/// The element type of the result is taken from `a`.
pub fn join(a: anytype, b: anytype) @Vector(vectorLength(@TypeOf(a)) + vectorLength(@TypeOf(b)), std.meta.Child(@TypeOf(a))) {
    const left_len = vectorLength(@TypeOf(a));
    const right_len = vectorLength(@TypeOf(b));
    // Non-negative mask entries pick lanes of `a`; an entry `~j` picks lane j of `b`.
    const mask = comptime (@as([left_len]i32, iota(i32, left_len)) ++ @as([right_len]i32, ~iota(i32, right_len)));
    return @shuffle(std.meta.Child(@TypeOf(a)), a, b, mask);
}
/// Interleaves several equal-length vectors into one vector.
///
/// `vecs` is a tuple or array of N vectors, all of type `@TypeOf(vecs[0])` with
/// length L. The result has length N * L and element `i * N + j` is `vecs[j][i]`:
/// `vecs[0][0], vecs[1][0], ..., vecs[N-1][0], vecs[0][1], ...`.
/// `deinterlace` performs the inverse operation.
///
/// Implementation: the first ceil(N/2) inputs and the last floor(N/2) inputs are
/// interlaced recursively. The two partial results are then merged with a single
/// `@shuffle` whose mask is computed at comptime.
pub fn interlace(vecs: anytype) @Vector(vectorLength(@TypeOf(vecs[0])) * vecs.len, std.meta.Child(@TypeOf(vecs[0]))) {
    comptime if (builtin.cpu.arch.isMIPS()) @compileError("TODO: Find out why interlace() doesn't work on MIPS");
    const VecType = @TypeOf(vecs[0]);
    // Coerce the argument (e.g. an anonymous tuple) into an array so it can be sliced.
    const vecs_arr = @as([vecs.len]VecType, vecs);
    const Child = std.meta.Child(@TypeOf(vecs_arr[0]));
    // Recursion base case: a single vector is already interlaced.
    if (vecs_arr.len == 1) return vecs_arr[0];
    // Split the inputs into a first group (rounded up) and a second group (rounded down).
    const a_vec_count = (1 + vecs_arr.len) >> 1;
    const b_vec_count = vecs_arr.len >> 1;
    const a = interlace(@ptrCast(*const [a_vec_count]VecType, vecs_arr[0..a_vec_count]).*);
    const b = interlace(@ptrCast(*const [b_vec_count]VecType, vecs_arr[a_vec_count..]).*);
    const a_len = vectorLength(@TypeOf(a));
    const b_len = vectorLength(@TypeOf(b));
    const len = a_len + b_len;
    const indices = comptime blk: {
        // Output lane k = cycle * N + slot, where cycle = k / N is the source element
        // index and slot = k mod N is the source vector index.
        const count_up = iota(i32, len);
        const cycle = @divFloor(count_up, @splat(len, @intCast(i32, vecs_arr.len)));
        // Per slot: true if the slot comes from the first group (`a`), false if from `b`.
        const select_mask = repeat(len, join(@splat(a_vec_count, true), @splat(b_vec_count, false)));
        // For slot < a_vec_count, the index into `a` is cycle * a_vec_count + slot.
        const a_indices = count_up - cycle * @splat(len, @intCast(i32, b_vec_count));
        // For slot >= a_vec_count, the index into `b` is cycle * b_vec_count + (slot - a_vec_count).
        // The right shift by a_vec_count moves each value to the lane that needs it.
        const b_indices = shiftElementsRight(count_up - cycle * @splat(len, @intCast(i32, a_vec_count)), a_vec_count, 0);
        // Negative mask entries (~x) select element x of the second @shuffle operand, `b`.
        break :blk @select(i32, select_mask, a_indices, ~b_indices);
    };
    return @shuffle(Child, a, b, indices);
}
/// Splits an interlaced vector back into `vec_count` separate vectors.
/// This is the inverse of `interlace`: output vector `w` receives input elements
/// `w, w + vec_count, w + 2 * vec_count, ...`.
pub fn deinterlace(
    comptime vec_count: usize,
    interlaced: anytype,
) [vec_count]@Vector(
    vectorLength(@TypeOf(interlaced)) / vec_count,
    std.meta.Child(@TypeOf(interlaced)),
) {
    const Interlaced = @TypeOf(interlaced);
    const Elem = std.meta.Child(Interlaced);
    const lane_count = vectorLength(Interlaced) / vec_count;
    const Lane = @Vector(lane_count, Elem);
    var result: [vec_count]Lane = undefined;
    comptime var which: usize = 0;
    inline while (which < vec_count) : (which += 1) {
        // Gather every vec_count-th element, starting at offset `which`.
        const gather = comptime iota(i32, lane_count) * @splat(lane_count, @intCast(i32, vec_count)) + @splat(lane_count, @intCast(i32, which));
        result[which] = @shuffle(Elem, interlaced, undefined, gather);
    }
    return result;
}
/// Returns the `count` consecutive elements of `vec` starting at index `first`.
///
/// Both `first` and `count` are comptime. A window that runs past the end of
/// `vec` (`first + count > len`) is rejected with a descriptive compile error.
pub fn extract(
    vec: anytype,
    comptime first: VectorIndex(@TypeOf(vec)),
    comptime count: VectorCount(@TypeOf(vec)),
) @Vector(count, std.meta.Child(@TypeOf(vec))) {
    const Child = std.meta.Child(@TypeOf(vec));
    const len = vectorLength(@TypeOf(vec));
    // All operands are comptime-known, so validate at compile time. A runtime assert
    // could never fire here. Without this check, an out-of-range window only shows
    // up as an opaque @shuffle mask error.
    if (@intCast(comptime_int, first) + @intCast(comptime_int, count) > len) {
        @compileError(std.fmt.comptimePrint(
            "extract: first ({d}) + count ({d}) exceeds vector length ({d})",
            .{ first, count, len },
        ));
    }
    return @shuffle(Child, vec, undefined, iota(i32, count) + @splat(count, @intCast(i32, first)));
}
// Covers repeat, join and extract. Except on MIPS, where interlace is a compile
// error, it also covers interlace and the interlace/deinterlace round trip.
test "vector patterns" {
    // Skipped on aarch64 for the stage1 and LLVM backends. The reason is not
    // recorded here (NOTE(review): presumably a backend bug — confirm).
    if ((builtin.zig_backend == .stage1 or builtin.zig_backend == .stage2_llvm) and
        builtin.cpu.arch == .aarch64)
    {
        return error.SkipZigTest;
    }
    const base = @Vector(4, u32){ 10, 20, 30, 40 };
    const other_base = @Vector(4, u32){ 55, 66, 77, 88 };
    // Five 2-lane vectors, so interlace recurses with an odd split (3 + 2).
    const small_bases = [5]@Vector(2, u8){
        @Vector(2, u8){ 0, 1 },
        @Vector(2, u8){ 2, 3 },
        @Vector(2, u8){ 4, 5 },
        @Vector(2, u8){ 6, 7 },
        @Vector(2, u8){ 8, 9 },
    };
    try std.testing.expectEqual([6]u32{ 10, 20, 30, 40, 10, 20 }, repeat(6, base));
    try std.testing.expectEqual([8]u32{ 10, 20, 30, 40, 55, 66, 77, 88 }, join(base, other_base));
    try std.testing.expectEqual([2]u32{ 20, 30 }, extract(base, 1, 2));
    if (comptime !builtin.cpu.arch.isMIPS()) {
        try std.testing.expectEqual([8]u32{ 10, 55, 20, 66, 30, 77, 40, 88 }, interlace(.{ base, other_base }));
        const small_braid = interlace(small_bases);
        try std.testing.expectEqual([10]u8{ 0, 2, 4, 6, 8, 1, 3, 5, 7, 9 }, small_braid);
        // deinterlace must exactly undo interlace.
        try std.testing.expectEqual(small_bases, deinterlace(small_bases.len, small_braid));
    }
}
/// Treats `a` and `b` as one double-width vector and returns the window of the
/// original width that starts `shift` elements into it.
/// `shift == 0` yields `a`; `shift == len` yields `b`.
pub fn mergeShift(a: anytype, b: anytype, comptime shift: VectorCount(@TypeOf(a, b))) @TypeOf(a, b) {
    const width = vectorLength(@TypeOf(a, b));
    const combined = join(a, b);
    return extract(combined, shift, width);
}
/// Moves every element `amount` positions toward higher indices, discarding
/// elements that fall off the end. The vacated low lanes are filled with `shift_in`.
pub fn shiftElementsRight(vec: anytype, comptime amount: VectorCount(@TypeOf(vec)), shift_in: std.meta.Child(@TypeOf(vec))) @TypeOf(vec) {
    const width = vectorLength(@TypeOf(vec));
    const filler = @splat(width, shift_in);
    return mergeShift(filler, vec, width - amount);
}
/// Moves every element `amount` positions toward lower indices, discarding
/// elements that fall off the front. The vacated high lanes are filled with `shift_in`.
pub fn shiftElementsLeft(vec: anytype, comptime amount: VectorCount(@TypeOf(vec)), shift_in: std.meta.Child(@TypeOf(vec))) @TypeOf(vec) {
    const filler = @splat(vectorLength(@TypeOf(vec)), shift_in);
    return mergeShift(vec, filler, amount);
}
/// Rotates elements `amount` positions toward lower indices. Elements leaving the
/// front re-enter at the back.
pub fn rotateElementsLeft(vec: anytype, comptime amount: VectorCount(@TypeOf(vec))) @TypeOf(vec) {
    // A rotation is a merge-shift of the vector with a copy of itself.
    const wrap_around = vec;
    return mergeShift(vec, wrap_around, amount);
}
/// Rotates elements `amount` positions toward higher indices. Elements leaving the
/// back re-enter at the front.
pub fn rotateElementsRight(vec: anytype, comptime amount: VectorCount(@TypeOf(vec))) @TypeOf(vec) {
    // Rotating right by `amount` equals rotating left by `len - amount`.
    const width = vectorLength(@TypeOf(vec));
    const left_amount = width - amount;
    return rotateElementsLeft(vec, left_amount);
}
/// Returns `vec` with its elements in reverse order.
pub fn reverseOrder(vec: anytype) @TypeOf(vec) {
    const Vec = @TypeOf(vec);
    const n = vectorLength(Vec);
    // Lane k reads source lane (n - 1 - k).
    const mask = comptime (@splat(n, @intCast(i32, n) - 1) - iota(i32, n));
    return @shuffle(std.meta.Child(Vec), vec, undefined, mask);
}
// Covers the shift, rotate and reverse helpers on a 4-lane vector. 999 marks the
// shifted-in lanes.
test "vector shifting" {
    const base = @Vector(4, u32){ 10, 20, 30, 40 };
    try std.testing.expectEqual([4]u32{ 30, 40, 999, 999 }, shiftElementsLeft(base, 2, 999));
    try std.testing.expectEqual([4]u32{ 999, 999, 10, 20 }, shiftElementsRight(base, 2, 999));
    try std.testing.expectEqual([4]u32{ 20, 30, 40, 10 }, rotateElementsLeft(base, 1));
    try std.testing.expectEqual([4]u32{ 40, 10, 20, 30 }, rotateElementsRight(base, 1));
    try std.testing.expectEqual([4]u32{ 40, 30, 20, 10 }, reverseOrder(base));
}
/// Returns the index of the first `true` lane of a boolean vector, or `null` if
/// every lane is `false`.
pub fn firstTrue(vec: anytype) ?VectorIndex(@TypeOf(vec)) {
    const Vec = @TypeOf(vec);
    const n = vectorLength(Vec);
    const Index = VectorIndex(Vec);
    if (@reduce(.Or, vec)) {
        // False lanes get the largest representable index, so they never win the minimum.
        const sentinel = @splat(n, ~@as(Index, 0));
        const candidates = @select(Index, vec, iota(Index, n), sentinel);
        return @reduce(.Min, candidates);
    }
    return null;
}
/// Returns the index of the last `true` lane of a boolean vector, or `null` if
/// every lane is `false`.
pub fn lastTrue(vec: anytype) ?VectorIndex(@TypeOf(vec)) {
    const Vec = @TypeOf(vec);
    const n = vectorLength(Vec);
    const Index = VectorIndex(Vec);
    if (@reduce(.Or, vec)) {
        // False lanes get index 0, so they never win the maximum. Index 0 can only
        // come out on top when lane 0 is the sole true lane.
        const floor = @splat(n, @as(Index, 0));
        const candidates = @select(Index, vec, iota(Index, n), floor);
        return @reduce(.Max, candidates);
    }
    return null;
}
/// Returns how many lanes of a boolean vector are `true`.
pub fn countTrues(vec: anytype) VectorCount(@TypeOf(vec)) {
    const Vec = @TypeOf(vec);
    const n = vectorLength(Vec);
    const Count = VectorCount(Vec);
    const ones = @splat(n, @as(Count, 1));
    const zeros = @splat(n, @as(Count, 0));
    // Map true -> 1 and false -> 0, then sum the lanes.
    return @reduce(.Add, @select(Count, vec, ones, zeros));
}
/// Returns the lowest index whose element equals `value`, or `null` if none does.
pub fn firstIndexOfValue(vec: anytype, value: std.meta.Child(@TypeOf(vec))) ?VectorIndex(@TypeOf(vec)) {
    const needle = @splat(vectorLength(@TypeOf(vec)), value);
    const hits = vec == needle;
    return firstTrue(hits);
}
/// Returns the highest index whose element equals `value`, or `null` if none does.
pub fn lastIndexOfValue(vec: anytype, value: std.meta.Child(@TypeOf(vec))) ?VectorIndex(@TypeOf(vec)) {
    const needle = @splat(vectorLength(@TypeOf(vec)), value);
    const hits = vec == needle;
    return lastTrue(hits);
}
/// Returns how many elements of `vec` equal `value`.
pub fn countElementsWithValue(vec: anytype, value: std.meta.Child(@TypeOf(vec))) VectorCount(@TypeOf(vec)) {
    const needle = @splat(vectorLength(@TypeOf(vec)), value);
    const hits = vec == needle;
    return countTrues(hits);
}
// Covers the value-search helpers on an 8-lane vector. Indices are u3 and counts
// are u4 for 8 lanes.
test "vector searching" {
    const base = @Vector(8, u32){ 6, 4, 7, 4, 4, 2, 3, 7 };
    try std.testing.expectEqual(@as(?u3, 1), firstIndexOfValue(base, 4));
    try std.testing.expectEqual(@as(?u3, 4), lastIndexOfValue(base, 4));
    try std.testing.expectEqual(@as(?u3, null), lastIndexOfValue(base, 99));
    try std.testing.expectEqual(@as(u4, 3), countElementsWithValue(base, 4));
}
/// Computes an inclusive prefix scan of `vec` using the binary operation `func`.
///
/// - `hop > 0` scans toward higher indices: lane k combines lanes k, k - hop,
///   k - 2 * hop, ...
/// - `hop < 0` scans toward lower indices.
/// - `hop == 0` is a compile error.
///
/// `identity` fills lanes shifted in from outside the vector, so it should be the
/// identity element of `func`. The lanes are combined in log2 steps (stride hop,
/// 2 * hop, 4 * hop, ...), so `func` is expected to be associative.
///
/// When `ErrorType` is `void`, `func` returns a plain vector. Otherwise it returns
/// `ErrorType!Vector` and errors are propagated.
pub fn prefixScanWithFunc(
    comptime hop: isize,
    vec: anytype,
    comptime ErrorType: type,
    comptime func: fn (@TypeOf(vec), @TypeOf(vec)) if (ErrorType == void) @TypeOf(vec) else ErrorType!@TypeOf(vec),
    comptime identity: std.meta.Child(@TypeOf(vec)),
) if (ErrorType == void) @TypeOf(vec) else ErrorType!@TypeOf(vec) {
    comptime if (builtin.cpu.arch.isMIPS()) @compileError("TODO: Find out why prefixScan doesn't work on MIPS");
    const len = vectorLength(@TypeOf(vec));
    if (hop == 0) @compileError("hop can not be 0; you'd be going nowhere forever!");
    const base_stride = if (hop < 0) -hop else hop;
    var result = vec;
    // The stride doubles each step, so the scan finishes in about log2(len / |hop|) steps.
    comptime var stride = base_stride;
    inline while (stride < len) : (stride <<= 1) {
        const shifted = if (hop < 0)
            shiftElementsLeft(result, stride, identity)
        else
            shiftElementsRight(result, stride, identity);
        result = if (ErrorType == void)
            func(result, shifted)
        else
            try func(result, shifted);
    }
    return result;
}
/// Computes an inclusive prefix scan of `vec` with the reduction operator `op`.
/// `hop` has the same meaning as in `prefixScanWithFunc`.
///
/// Supported operators per element type:
/// - bool: `.And`, `.Or`, `.Xor`
/// - integers: every `ReduceOp`
/// - floats: `.Add`, `.Mul`, `.Min`, `.Max`
///
/// Any other combination is a compile error.
pub fn prefixScan(comptime op: std.builtin.ReduceOp, comptime hop: isize, vec: anytype) @TypeOf(vec) {
    const VecType = @TypeOf(vec);
    const Child = std.meta.Child(VecType);
    const len = vectorLength(VecType);
    // Identity element of `op` for `Child`. It fills lanes shifted in from outside
    // the vector and must not change any value it is combined with.
    const identity = comptime switch (@typeInfo(Child)) {
        .Bool => switch (op) {
            .Or, .Xor => false,
            .And => true,
            else => @compileError("Invalid prefixScan operation " ++ @tagName(op) ++ " for vector of booleans."),
        },
        .Int => switch (op) {
            // Every bit set. For signed types, maxInt has the sign bit clear, so
            // using it would wrongly clear the sign bit of the leading lanes.
            .And => ~@as(Child, 0),
            .Add, .Or, .Xor => 0,
            .Mul => 1,
            .Min => std.math.maxInt(Child),
            .Max => std.math.minInt(Child),
        },
        .Float => switch (op) {
            .Max => -std.math.inf(Child),
            .Add => 0,
            .Mul => 1,
            .Min => std.math.inf(Child),
            else => @compileError("Invalid prefixScan operation " ++ @tagName(op) ++ " for vector of floats."),
        },
        else => @compileError("Invalid type " ++ @typeName(VecType) ++ " for prefixScan."),
    };
    const fn_container = struct {
        fn opFn(a: VecType, b: VecType) VecType {
            return if (Child == bool) switch (op) {
                // Lane-wise boolean logic, expressed with @select.
                .And => @select(bool, a, b, @splat(len, false)),
                .Or => @select(bool, a, @splat(len, true), b),
                .Xor => a != b,
                else => unreachable,
            } else switch (op) {
                .And => a & b,
                .Or => a | b,
                .Xor => a ^ b,
                .Add => a + b,
                .Mul => a * b,
                .Min => @min(a, b),
                .Max => @max(a, b),
            };
        }
    };
    return prefixScanWithFunc(hop, vec, void, fn_container.opFn, identity);
}
// Covers prefixScan for every supported operator on int, float and bool vectors,
// plus non-unit hops (2 and 3) and a reverse scan (hop = -1).
test "vector prefix scan" {
    // prefixScan is a compile error on MIPS.
    if (comptime builtin.cpu.arch.isMIPS()) {
        return error.SkipZigTest;
    }
    // Also skipped for the stage1 and LLVM backends. The reason is not recorded
    // here (NOTE(review): confirm).
    if (builtin.zig_backend == .stage1 or builtin.zig_backend == .stage2_llvm) {
        return error.SkipZigTest;
    }
    const int_base = @Vector(4, i32){ 11, 23, 9, -21 };
    const float_base = @Vector(4, f32){ 2, 0.5, -10, 6.54321 };
    const bool_base = @Vector(4, bool){ true, false, true, false };
    // Running sum of 32 ones is 1, 2, ..., 32.
    try std.testing.expectEqual(iota(u8, 32) + @splat(32, @as(u8, 1)), prefixScan(.Add, 1, @splat(32, @as(u8, 1))));
    try std.testing.expectEqual(@Vector(4, i32){ 11, 3, 1, 1 }, prefixScan(.And, 1, int_base));
    try std.testing.expectEqual(@Vector(4, i32){ 11, 31, 31, -1 }, prefixScan(.Or, 1, int_base));
    try std.testing.expectEqual(@Vector(4, i32){ 11, 28, 21, -2 }, prefixScan(.Xor, 1, int_base));
    try std.testing.expectEqual(@Vector(4, i32){ 11, 34, 43, 22 }, prefixScan(.Add, 1, int_base));
    try std.testing.expectEqual(@Vector(4, i32){ 11, 253, 2277, -47817 }, prefixScan(.Mul, 1, int_base));
    try std.testing.expectEqual(@Vector(4, i32){ 11, 11, 9, -21 }, prefixScan(.Min, 1, int_base));
    try std.testing.expectEqual(@Vector(4, i32){ 11, 23, 23, 23 }, prefixScan(.Max, 1, int_base));
    try std.testing.expectEqual(@Vector(4, f32){ 2, 0.5, -10, -10 }, prefixScan(.Min, 1, float_base));
    try std.testing.expectEqual(@Vector(4, f32){ 2, 2, 2, 6.54321 }, prefixScan(.Max, 1, float_base));
    try std.testing.expectEqual(@Vector(4, bool){ true, true, false, false }, prefixScan(.Xor, 1, bool_base));
    try std.testing.expectEqual(@Vector(4, bool){ true, true, true, true }, prefixScan(.Or, 1, bool_base));
    try std.testing.expectEqual(@Vector(4, bool){ true, false, false, false }, prefixScan(.And, 1, bool_base));
    // hop = 2: each lane accumulates lanes 2 apart (11, 23, 11+9, 23-21).
    try std.testing.expectEqual(@Vector(4, i32){ 11, 23, 20, 2 }, prefixScan(.Add, 2, int_base));
    // hop = -1: suffix sums, scanning from the last lane backward.
    try std.testing.expectEqual(@Vector(4, i32){ 22, 11, -12, -21 }, prefixScan(.Add, -1, int_base));
    // hop = 3: only lane 3 picks up lane 0.
    try std.testing.expectEqual(@Vector(4, i32){ 11, 23, 9, -10 }, prefixScan(.Add, 3, int_base));
}