python/skewb_solver.py
DomNomNomVR 028fb0bc3a dump
2025-04-14 15:58:38 +12:00

751 lines
22 KiB
Python

from math import sqrt
import shelve
from collections import deque
from dataclasses import dataclass, replace
from functools import reduce
import random
from typing import TYPE_CHECKING, Callable, Counter, Dict, List, Literal, Set, Tuple
from functools import lru_cache
import pytest
# Emit a blank line at import time (presumably to separate this module's
# output from pytest's own header).
print()

# Twist axis: twisting the bottom, opposing this colour, clockwise when this
# colour is facing away.
Axis = Literal["R", "B", "O", "G"]
axes: tuple[Axis, ...] = Axis.__args__

# How many clockwise twists a corner has had along an axis since it was last
# flush with the top or bottom.
CornerRotation = Literal[0, 1, 2]

# Upper case = clockwise twist about that axis, lower case = anticlockwise.
LowerAntiAxis = Literal["R", "B", "O", "G", "r", "b", "o", "g"]
lower_anti_axes: tuple[LowerAntiAxis, ...] = LowerAntiAxis.__args__

# Direction conversions between the two spellings of each axis.
to_anticlockwise: dict[Axis, LowerAntiAxis] = {axis: axis.lower() for axis in axes}
to_clockwise: dict[LowerAntiAxis, Axis] = {
    anti: axis for axis, anti in to_anticlockwise.items()
}
# Maps every twist to its inverse (same axis, other direction).
to_opposite: dict[LowerAntiAxis, LowerAntiAxis] = to_anticlockwise | to_clockwise
@dataclass(frozen=True)
class Corner:
    """One corner piece: its colour (``col``) and rotation (``rot``)."""

    col: Axis
    rot: CornerRotation

    def __repr__(self) -> str:
        # Compact form such as "R0", matching the module-level shorthand names.
        return "".join((self.col, str(self.rot)))
# Orientation label of a centre ("middle") piece.
MidRotatation = Literal["Y", "RG", "RB"]


@dataclass(frozen=True)
class Middle:
    """One centre piece: its colour (``col``) and orientation label (``rot``)."""

    col: Literal["R", "B", "O", "G", "Y"]
    rot: MidRotatation

    def __repr__(self) -> str:
        # Compact form such as "RRG", matching the module-level shorthand names.
        return self.col + self.rot
@dataclass(frozen=True)
class Skewb:
    """Represents a Rubik's cube variant sold as 'Qiyi Twisty Skewb'."""

    # Corner slots, ordered by the Axis colour each sits vertically below:
    # "R", "B", "O", "G".
    top: Tuple[Corner, Corner, Corner, Corner]
    bot: Tuple[Corner, Corner, Corner, Corner]
    # Centre slots, in the order RB, BO, OG, GR, Y.
    mids: Tuple[Middle, Middle, Middle, Middle, Middle]

    def assert_valid(self):
        """Assert basic invariants (colour counts and allowed centre labels).

        Raises AssertionError at the first violated invariant.
        """
        corner_cols = Counter(c.col for c in (*self.top, *self.bot))
        assert corner_cols == {col: 2 for col in axes}
        assert Counter(m.col for m in self.mids) == Counter("RBOGY")
        # Each side centre slot forbids one of the two non-"Y" labels.
        for mid, forbidden in zip(self.mids[:4], ("RB", "RG", "RB", "RG")):
            assert mid.rot != forbidden
        # The bottom centre may carry the "Y" label only if it is yellow.
        bottom = self.mids[4]
        assert bottom.col == "Y" or bottom.rot != "Y"
# Shorthand corner pieces: colour letter followed by rotation.
R0, R1, R2 = (Corner("R", rot) for rot in (0, 1, 2))
B0, B1, B2 = (Corner("B", rot) for rot in (0, 1, 2))
O0, O1, O2 = (Corner("O", rot) for rot in (0, 1, 2))
G0, G1, G2 = (Corner("G", rot) for rot in (0, 1, 2))

# Shorthand centre pieces: colour letter followed by orientation label.
RY, RRG, RRB = (Middle("R", rot) for rot in ("Y", "RG", "RB"))
BY, BRG, BRB = (Middle("B", rot) for rot in ("Y", "RG", "RB"))
OY, ORG, ORB = (Middle("O", rot) for rot in ("Y", "RG", "RB"))
GY, GRG, GRB = (Middle("G", rot) for rot in ("Y", "RG", "RB"))
YY, YRG, YRB = (Middle("Y", rot) for rot in ("Y", "RG", "RB"))
def rotate_side_mid_rot_about_W(m: Middle) -> Middle:
    """Relabel a centre for a whole-puzzle turn about W.

    "Y" stays "Y", "RG" becomes "RB", and any other label becomes "RG".
    The colour is kept.
    """
    if m.rot == "Y":
        new_rot = "Y"
    elif m.rot == "RG":
        new_rot = "RB"
    else:
        new_rot = "RG"
    return Middle(col=m.col, rot=new_rot)
def rotate_bot_mid_rot_about_W(m: Middle) -> Middle:
    """Bottom-centre variant of the W-turn relabelling.

    The yellow centre always comes back as ``Middle("Y", "Y")``. Any other
    centre keeps its colour; "RG" becomes "RB" and every other label "RG".

    NOTE(review): not referenced in the visible code;
    rotate_everything_about_W uses the side variant for the bottom centre too.
    """
    if m.col != "Y":
        return Middle(col=m.col, rot="RB" if m.rot == "RG" else "RG")
    return Middle("Y", "Y")
def rotate_everything_about_W(s: Skewb) -> Skewb:
    """Clockwise rotation of the whole puzzle when looking down upon white.

    Both corner rings shift one slot (the last slot wraps to the front), the
    four side centres shift the same way, the bottom centre stays put, and
    every centre is relabelled with ``rotate_side_mid_rot_about_W``.
    """
    # New centre slot k takes the piece from old slot source_slots[k].
    source_slots = (3, 0, 1, 2, 4)
    return Skewb(
        top=(s.top[-1], *s.top[:-1]),
        bot=(s.bot[-1], *s.bot[:-1]),
        mids=tuple(rotate_side_mid_rot_about_W(s.mids[k]) for k in source_slots),
    )
# The goal state: the default target of the solvers below.
solved_skewb = Skewb(
    mids=(BY, RY, GY, OY, YY),
    top=(O0, B0, R0, G0),
    bot=(B0, R0, G0, O0),
)
# A fixed scrambled state used by the tests.
desk_start = Skewb(
    mids=(BY, OY, RRG, GRB, YY),
    top=(R0, B0, O0, G2),
    bot=(B0, O0, G1, R0),
)
@pytest.mark.parametrize("s", [desk_start, solved_skewb])
def test_rotate_everything_about_W(s: Skewb):
    """Four turns about W restore the state; one to three turns never do."""
    states = [s]
    for _ in range(4):
        states.append(rotate_everything_about_W(states[-1]))
    *intermediate, final = states[1:]
    assert final == s
    assert not any(q == s for q in intermediate)
def test_axes():
    """Red is the first axis (the reference axis for twists and instructions)."""
    first, *_ = axes
    assert first == "R"
type CornerRotPermutation = dict[CornerRotation, CornerRotation]
BOT_LEFT_TO_TOP: CornerRotPermutation = {0: 2, 1: 0, 2: 1}
TOP_TO_BOT_RIGHT: CornerRotPermutation = {0: 2, 1: 0, 2: 1}
ROTATE_CORNER_CLOCKWISE: CornerRotPermutation = {0: 1, 1: 2, 2: 0}
BOT_RIGHT_TO_BOT_LEFT: CornerRotPermutation = {0: 2, 1: 0, 2: 1}
@pytest.mark.parametrize(
    "p",
    [BOT_LEFT_TO_TOP, TOP_TO_BOT_RIGHT, ROTATE_CORNER_CLOCKWISE, BOT_RIGHT_TO_BOT_LEFT],
)
def test_rotation_permutations(p: CornerRotPermutation):
    """Every rotation table is a bijection on {0, 1, 2}."""
    rotations = {0, 1, 2}
    assert p.keys() == rotations
    assert set(p.values()) == rotations
# Cyclic relabelling of a centre used by clockwise_twist:
# "RG" -> "Y" -> "RB" -> "RG".
MID_DIR_INCREMENT: dict[MidRotatation, MidRotatation] = dict(
    RG="Y",
    Y="RB",
    RB="RG",
)
def clockwise_twist(s: Skewb, twist: Axis) -> Skewb:
    """Return the state after one clockwise twist about ``twist``.

    Only the "R" twist is written out below, as a fixed slot permutation.
    For other axes the whole puzzle is first turned about W ``rot_before``
    times (presumably to bring the requested axis to where "R" sits), the
    fixed twist is applied, and the puzzle is turned ``rot_after`` more times.
    Each pair sums to 0 or 4 turns, so the puzzle's overall orientation is
    restored (four turns about W are the identity).
    """
    rot_before, rot_after = {"R": (0, 0), "B": (3, 1), "O": (2, 2), "G": (1, 3)}[twist]
    for _ in range(rot_before):
        s = rotate_everything_about_W(s)
    # Fixed "R" twist. Corners cycle top[2] -> bot[1] -> bot[3] -> top[2] and
    # bot[2] turns in place; centres cycle mids[1] -> mids[4] -> mids[2] ->
    # mids[1]. All other slots are untouched.
    s = Skewb(
        top=(
            s.top[0],
            s.top[1],
            # bot[3] moves up to top[2].
            Corner((c := s.bot[3]).col, BOT_LEFT_TO_TOP[c.rot]),
            s.top[3],
        ),
        bot=(
            s.bot[0],
            # top[2] moves down to bot[1].
            Corner((c := s.top[2]).col, TOP_TO_BOT_RIGHT[c.rot]),
            # bot[2] stays in its slot but is twisted clockwise.
            Corner((c := s.bot[2]).col, ROTATE_CORNER_CLOCKWISE[c.rot]),
            # bot[1] moves to bot[3].
            Corner((c := s.bot[1]).col, BOT_RIGHT_TO_BOT_LEFT[c.rot]),
        ),
        mids=(
            s.mids[0],
            # mids[2] moves to slot 1. The yellow centre is always labelled "Y".
            Middle(
                (m := s.mids[2]).col,
                "Y"
                if m.col == "Y"
                else MID_DIR_INCREMENT[m.rot],  # ("Y" if m.rot == "RG" else "RB"),
            ),
            # The bottom centre mids[4] moves to slot 2.
            Middle(
                (m := s.mids[4]).col,
                "Y"
                if m.col == "Y"
                else MID_DIR_INCREMENT[m.rot],  # ("RG" if m.rot == "RB" else "Y"),
            ),
            s.mids[3],
            # mids[1] moves to the bottom slot. For the labels assert_valid
            # allows in slot 1 ("Y" and "RB") this agrees with
            # MID_DIR_INCREMENT; it differs only for "RG" (kept as "RG").
            Middle(
                (m := s.mids[1]).col,
                "Y" if m.col == "Y" else ("RB" if m.rot == "Y" else "RG"),
            ),
        ),
    )
    for _ in range(rot_after):
        s = rotate_everything_about_W(s)
    return s
def anticlockwise_twist(s: Skewb, twist: Axis) -> Skewb:
    """Return the state after one anticlockwise twist about ``twist``.

    Implemented as two clockwise twists about the same axis.
    """
    once = clockwise_twist(s, twist)
    return clockwise_twist(once, twist)
@pytest.mark.parametrize("start", [desk_start, solved_skewb])
@pytest.mark.parametrize("axis", axes)
def test_clockwise_twist(start: Skewb, axis: Axis):
    """Clockwise and anticlockwise twists invert each other; three clockwise
    twists about one axis are the identity."""

    def check_round_trips(state: Skewb) -> None:
        assert anticlockwise_twist(clockwise_twist(state, axis), axis) == state
        assert clockwise_twist(anticlockwise_twist(state, axis), axis) == state

    check_round_trips(start)
    states = [start]
    while len(states) < 4:
        states.append(clockwise_twist(states[-1], axis))
    assert len(states) == 4
    assert states[-1] == start
    assert all(q != start for q in states[1:-1])
    for state in states:
        check_round_trips(state)
def test_clockwise_twist_simple():
    """Pin the exact result of one R twist applied to desk_start."""
    expected = Skewb(
        top=(R0, B0, R2, G2),
        bot=(B0, O2, G2, O2),
        mids=(BY, RY, YY, GRB, ORB),
    )
    assert clockwise_twist(desk_start, "R") == expected
def test_clockwise_twist_simple2():
    """Pin the exact result of two R twists applied to desk_start."""
    twice = clockwise_twist(clockwise_twist(desk_start, "R"), "R")
    assert twice == Skewb(
        top=(R0, B0, O1, G2),
        bot=(B0, R1, G0, O1),
        mids=(BY, YY, ORG, GRB, RRB),
    )
def test_clockwise_twist_simple3():
    """Pin one G twist between two states that both pass assert_valid."""
    start = Skewb(
        top=(R0, B0, O0, G0),
        bot=(G0, R1, B2, O1),
        mids=(YY, OY, RRG, BRB, GRG),
    )
    end = Skewb(
        top=(R0, B1, O0, G0),
        bot=(B2, R2, G2, O1),
        mids=(ORG, GRB, RRG, BRB, YY),
    )
    for state in (start, end):
        state.assert_valid()
    assert clockwise_twist(start, "G") == end
def apply_twist(start: Skewb, twist: LowerAntiAxis) -> Skewb:
    """Apply one twist: upper case is clockwise, lower case anticlockwise."""
    if twist not in axes:
        return anticlockwise_twist(start, to_clockwise[twist])
    return clockwise_twist(start, twist)
def apply_opposite(s: Skewb, twist: LowerAntiAxis) -> Skewb:
    """Undo ``twist`` by applying its inverse (same axis, other direction)."""
    inverse = to_opposite[twist]
    return apply_twist(s, inverse)
def apply_twists(start: Skewb, twists: list[LowerAntiAxis]) -> Skewb:
    """Apply ``twists`` in order, left to right, starting from ``start``."""
    state = start
    for twist in twists:
        state = apply_twist(state, twist)
    return state
def instructions(twists: list[LowerAntiAxis]) -> str:
    """Render a twist sequence as a compact hand-movement string.

    The "R" axis starts selected. Each "L" advances the selection one step
    through ``axes`` (R -> B -> O -> G -> R), and each "." performs one
    clockwise twist about the selected axis. Afterwards ".." is shortened to
    ":" and "LLL" (three steps forward, i.e. one step back) to "J".

    Lower-case (anticlockwise) twists are written as two clockwise twists,
    the same way anticlockwise_twist implements them.

    Raises:
        ValueError: if a twist is neither one of ``axes`` nor its lower-case form.
    """
    pieces: list[str] = []
    current = axes.index("R")
    for twist in twists:
        if twist in axes:
            clockwise, repeats = twist, 1
        elif twist in to_clockwise:
            clockwise, repeats = to_clockwise[twist], 2
        else:
            # The old step-until-equal loop never terminated on such input.
            raise ValueError(f"unknown twist {twist!r}")
        target = axes.index(clockwise)
        pieces.append("L" * ((target - current) % len(axes)))
        pieces.append("." * repeats)
        current = target
    out = "".join(pieces)
    return out.replace("..", ":").replace("LLL", "J")
def breadth_first_search(
    start: Skewb, is_end: Callable[[Skewb], bool], max_steps: int = 2000000
) -> list[LowerAntiAxis] | None:
    """Breadth-first search from ``start`` for a state satisfying ``is_end``.

    All eight twists in ``lower_anti_axes`` count as one move each, so the
    returned path has the minimum number of such moves.

    Args:
        start: initial state; checked with ``assert_valid`` first.
        is_end: goal predicate, tested on ``start`` and on every newly
            discovered state.
        max_steps: maximum number of states dequeued and expanded.

    Returns:
        The twists leading from ``start`` to the first goal state found
        (``[]`` if ``start`` already satisfies ``is_end``), or ``None`` if the
        budget runs out or every reachable state has been visited.
    """
    start.assert_valid()
    if is_end(start):
        return []
    q = deque([start])
    # Maps each discovered state to the twist that first reached it
    # (None for start).
    skewb_to_twist: dict[Skewb, LowerAntiAxis | None] = {skewb: None for skewb in q}

    def get_path(end: Skewb) -> list[LowerAntiAxis]:
        # Walk back to start (the only state mapped to None; twist strings are
        # truthy) by undoing each recorded twist, then reverse into
        # start -> end order.
        out = []
        s = end
        while twist := skewb_to_twist[s]:
            out.append(twist)
            s = apply_opposite(s, twist)
        out.reverse()
        return out

    while q and max_steps > 0:
        max_steps -= 1
        # Progress indicator: one dot every 1000 expansions.
        if max_steps % 1000 == 0:
            print(".", end="", flush=True)
        parent = q.popleft()
        for twist in lower_anti_axes:
            child = apply_twist(parent, twist)
            if child in skewb_to_twist:
                continue
            skewb_to_twist[child] = twist
            # Goal test on discovery rather than on expansion.
            if is_end(child):
                return get_path(child)
            q.append(child)
    return None
def test_breadth_first_search():
    """A state one twist away from solved is solved by exactly that twist."""

    def is_solved(s: Skewb) -> bool:
        return s == solved_skewb

    for twist in lower_anti_axes:
        scrambled = apply_opposite(solved_skewb, twist)
        assert breadth_first_search(scrambled, is_end=is_solved) == [twist]
def print_path(path: list[LowerAntiAxis], initial: Skewb | None = None) -> None:
    """Print every intermediate state while following ``path``.

    Args:
        path: twists to apply in order. Upper case is clockwise; lower-case
            (anticlockwise) twists are accepted as well.
        initial: state to start from. Defaults to the module-level ``start``,
            looked up at call time.
    """
    x = start if initial is None else initial
    print(f"S -> {x}")
    for twist in path:
        # apply_twist equals clockwise_twist for upper-case twists and also
        # understands lower-case ones (clockwise_twist raised KeyError on them).
        x = apply_twist(x, twist)
        print(f"{twist} -> {x}")
def bidirectional_search(
    start: Skewb, max_steps: int, end: Skewb = solved_skewb
) -> list[LowerAntiAxis] | None:
    """Find twists turning ``start`` into ``end`` by searching from both ends.

    A forward breadth-first frontier grows from ``start`` and a backward one
    from ``end``, alternating one node expansion each per step. When a state
    is reached by both, the two half-paths are joined. The frontiers are not
    advanced level by level, so the result is not guaranteed to be shortest.

    Args:
        start: state to solve; checked with ``assert_valid``.
        max_steps: maximum number of alternating expansion rounds.
        end: target state (default ``solved_skewb``).

    Returns:
        The twists to apply to ``start`` (``[]`` when ``start == end``), or
        ``None`` if the budget runs out or either side exhausts its reachable
        states without meeting the other.
    """
    start.assert_valid()
    if start == end:
        # Without this the search below would "meet" one twist away and
        # return a pointless twist-and-undo pair.
        return []
    q = deque([start])
    q2 = deque([end])
    # Forward: twist applied to a state's parent to reach it (None at start).
    skewb_to_twist: dict[Skewb, LowerAntiAxis | None] = {start: None}
    # Backward: twist that, applied to the state, moves one step towards end
    # (None at end).
    skewb_to_twist2: dict[Skewb, LowerAntiAxis | None] = {end: None}

    def get_path(meet: Skewb) -> list[LowerAntiAxis]:
        # Forward half: walk back from meet to start, then reverse.
        path = []
        s = meet
        while twist := skewb_to_twist[s]:
            path.append(twist)
            s = apply_opposite(s, twist)
        path.reverse()
        # Backward half: walk forward from meet to end.
        s = meet
        while twist := skewb_to_twist2[s]:
            path.append(twist)
            s = apply_twist(s, twist)
        return path

    def on_meet(meet: Skewb) -> list[LowerAntiAxis]:
        path = get_path(meet)
        # Sanity check: the joined path really takes start to end.
        assert apply_twists(start, path) == end
        return path

    # Both queues are checked: if either empties, that side has seen every
    # state it can reach without meeting the other, so no path exists. (An
    # empty q2 would otherwise make q2.popleft() raise IndexError.)
    while q and q2 and max_steps > 0:
        max_steps -= 1
        # Progress indicator: one dot every 1000 rounds.
        if max_steps % 1000 == 0:
            print(".", end="", flush=True)
        parent2 = q2.popleft()
        for twist in lower_anti_axes:
            child = apply_opposite(parent2, twist)
            if child in skewb_to_twist2:
                continue
            skewb_to_twist2[child] = twist
            if child in skewb_to_twist:
                return on_meet(child)
            q2.append(child)
        parent = q.popleft()
        for twist in lower_anti_axes:
            child = apply_twist(parent, twist)
            if child in skewb_to_twist:
                continue
            skewb_to_twist[child] = twist
            if child in skewb_to_twist2:
                return on_meet(child)
            q.append(child)
    print(f"{len(skewb_to_twist)}")
    return None
def heuristic(got: Skewb) -> float:
    """Total closeness score of ``got`` to ``solved_skewb``.

    This is the sum of ``heuristic_list``'s components; higher means more
    pieces match the solved state.
    """
    components = heuristic_list(got)
    return sum(components)
def heuristic_list(got: Skewb) -> list[float]:
    """Score ``got`` against ``solved_skewb``, one entry per feature.

    The first seven entries are weighted exact matches of particular slots
    (weights 10000, 1000, 1000, 100, 100, 100, 100). Then there is one entry
    per corner (1 for colour + 2 for rotation) and one per centre (1 for
    orientation + 4 for colour).
    """
    want = solved_skewb
    prioritised = [
        (10000, got.top[0] == want.top[0]),
        (1000, got.bot[3] == want.bot[3]),
        (1000, got.mids[3] == want.mids[3]),
        (100, got.top[1] == want.top[1]),
        (100, got.mids[0] == want.mids[0]),
        (100, got.bot[0] == want.bot[0]),
        (100, got.mids[4] == want.mids[4]),
    ]
    out = [weight * matched for weight, matched in prioritised]
    out.extend(
        (c_got.col == c_want.col) + 2 * (c_got.rot == c_want.rot)
        for c_got, c_want in zip(got.top + got.bot, want.top + want.bot)
    )
    out.extend(
        (m_got.rot == m_want.rot) + 4 * (m_got.col == m_want.col)
        for m_got, m_want in zip(got.mids, want.mids)
    )
    return out
def get_heuristic_matches(s: Skewb) -> List[int]:
    """Flag, slot by slot, whether ``s`` matches ``solved_skewb`` exactly.

    Layout of the returned list:
        [0:4]  top corners
        [4:8]  bottom corners
        [8:13] centres
    Each entry is 1 when both colour and rotation match, otherwise 0.
    """
    want = solved_skewb
    slot_pairs = [
        *zip(s.top + s.bot, want.top + want.bot),
        *zip(s.mids, want.mids),
    ]
    return [int(have.col == goal.col and have.rot == goal.rot) for have, goal in slot_pairs]
def element_multiply(a: list[int], b: list[int]) -> list[int]:
    """Return the element-wise product of two equal-length lists."""
    assert len(a) == len(b)
    return [left * right for left, right in zip(a, b)]
def test_heuristic():
    """The solved state scores strictly higher than desk_start."""
    assert heuristic(solved_skewb) > heuristic(desk_start)
def random_skewb(seed: int = 4, twists: int = 20) -> Skewb:
    """Scramble ``solved_skewb`` with ``random_skewb_twists(seed, twists)``."""
    scramble = random_skewb_twists(seed, twists)
    return apply_twists(solved_skewb, scramble)
def random_skewb_twists(seed: int = 4, twists: int = 20) -> list[LowerAntiAxis]:
    """Return a reproducible random sequence of ``twists`` twists.

    Consecutive twists always use different axes. Each twist is clockwise or
    (lower case) anticlockwise according to one random bit.
    """
    rng = random.Random(seed)
    axis_index = rng.randint(0, 3)
    sequence: list[LowerAntiAxis] = []
    for _ in range(twists):
        chosen: LowerAntiAxis = axes[axis_index]
        flip = rng.getrandbits(1)
        sequence.append(to_opposite[chosen] if flip else chosen)
        # Advance 1-3 places around the ring so the next axis always differs.
        axis_index = (axis_index + rng.randint(1, 3)) % len(axes)
    return sequence
def double_clockwise_to_anticlockwise(twists: list[Axis]) -> list[LowerAntiAxis]:
    """Rewrite each adjacent pair of equal clockwise twists as one anticlockwise twist.

    Pairs are taken greedily from the left, so "RRR" becomes ["r", "R"].
    """
    result: list[LowerAntiAxis] = []
    pos = 0
    while pos < len(twists):
        current = twists[pos]
        is_pair = pos + 1 < len(twists) and twists[pos + 1] == current
        if is_pair:
            result.append(to_anticlockwise[current])
            pos += 2
        else:
            result.append(current)
            pos += 1
    return result
@pytest.mark.parametrize(
    "twists, want",
    [
        ("RBOG", "RBOG"),
        ("RRBBOOGG", "rbog"),
        ("RRR", "rR"),
        ("RROR", "rOR"),
        ("RORR", "ROr"),
    ],
)
def test_double_clockwise_to_anticlockwise(
    twists: list[Axis], want: list[LowerAntiAxis]
):
    """Adjacent equal clockwise twists collapse into one anticlockwise twist."""
    expected = [*want]
    assert double_clockwise_to_anticlockwise(twists) == expected
def test_random_skewb():
    """random_skewb_twists returns exactly the requested number of twists."""
    count = 50
    assert len(random_skewb_twists(twists=count)) == count
def shelve_it(file_name):
    """Decorator factory: memoise a function's results in a ``shelve`` file.

    The shelf at ``file_name`` is opened once, when the decorator is created,
    and is never explicitly closed. Results are keyed by
    ``str(args) + str(kwargs)``, so arguments need a stable, distinct ``str``
    form. The same call made positionally and by keyword is cached under
    different keys. Cached values come back unpickled from the shelf, so
    callers get a copy, not the original object.
    """
    from functools import wraps

    d = shelve.open(file_name)

    def decorator(func):
        @wraps(func)  # keep the wrapped function's name and docstring
        def new_func(*args, **kwargs):
            key = str(args) + str(kwargs)
            if key not in d:
                d[key] = func(*args, **kwargs)
                # Flush each new result right away. The shelf is never
                # closed, and some dbm backends only persist on sync/close,
                # so an interrupted long run could otherwise lose results.
                d.sync()
            return d[key]

        return new_func

    return decorator
def get_paths_from_heuristic(
    start: Skewb, heuristic_permutation: list[int]
) -> list[list[LowerAntiAxis]]:
    """Solve ``start`` in stages, one stage per entry of ``heuristic_permutation``.

    Each entry is an index into ``get_heuristic_matches``. Stages are
    cumulative: a stage searches for a state in which every index named so
    far matches ``solved_skewb``. Each stage uses ``breadth_first_search``,
    except the stage whose index equals ``len(heuristic_permutation) - 3``,
    which runs ``bidirectional_search`` towards the fully solved state.

    Returns:
        One twist list per stage, to be applied in order.

    Raises:
        ValueError: if a stage's search finds no path.
    """
    out: list[list[LowerAntiAxis]] = []
    s = start
    mask = [0] * len(heuristic_permutation)

    def step_finished(candidate: Skewb) -> bool:
        # Reads ``mask`` when called, so it always reflects the current stage.
        matches = get_heuristic_matches(candidate)
        assert len(matches) == len(mask)
        return all(match >= m for match, m in zip(matches, mask))

    for heuristic_i in heuristic_permutation:
        mask[heuristic_i] = 1
        # NOTE(review): this compares the heuristic index *value* with len - 3
        # (10 for a 13-entry permutation), not the stage's position in the
        # permutation; confirm which was intended.
        if heuristic_i == len(heuristic_permutation) - 3:
            path = bidirectional_search(s, max_steps=200000)
        else:
            path = breadth_first_search(s, step_finished)
        print(f"{mask=} {s=} {path=}")
        if path is None:
            raise ValueError("oh no! solver could not find solution")
        out.append(path)
        s = apply_twists(s, path)
    return out
# Development notes, kept for reference:
# def get_total_path_length(start: Skewb, heuristic_permutation: list[int]) -> int:
#     return sum(len(p) for p in get_paths_from_heuristic(start, heuristic_permutation))
# close_to_wrongly_solved = Skewb(top=(R0, B0, O0, G0), bot=(B0, O0, G0, R0), mids=(BY, GRB, ORG, RY, YY))
# start = Skewb(top=(O0, B0, R1, G1), bot=(B0, R2, G2, O0), mids=(BY, RY, GY, OY, YY))

# Identical to solved_skewb except for the G corners (top[3] is G2, bot[2] is G1).
near_end = Skewb(
    top=(O0, B0, R0, G2),
    bot=(B0, R0, G1, O0),
    mids=(BY, RY, GY, OY, YY),
)
# Default starting state read by print_path.
start = near_end

# Length of get_heuristic_matches' result: 4 top + 4 bottom corners + 5 centres.
HURISTIC_PERMUTATION_LENGTH = 4 + 4 + 5
def quadratic_mean(values: list[float]) -> float:
    """Return the root mean square of ``values``: sqrt(sum(x * x) / n).

    Any iterable of numbers is accepted; it is materialised into a list first
    so that generators work too.

    Raises:
        ValueError: if ``values`` is empty (previously an opaque ZeroDivisionError).
    """
    values = list(values)
    if not values:
        raise ValueError("quadratic_mean() requires at least one value")
    return sqrt(sum(x * x for x in values) / len(values))
def get_mean_path_length(
    start: Skewb,
    heuristic_permutation: list[int],
    mean_fn: Callable[[list[float]], float] = quadratic_mean,
) -> float:
    """Combine, via ``mean_fn``, the per-stage path lengths of solving ``start``."""
    stages = get_paths_from_heuristic(start, heuristic_permutation)
    lengths = [len(stage) for stage in stages]
    return mean_fn(lengths)
@shelve_it("skewb_solver.evaluate_permutation.shelve.sqlite")
def evaluate_permutation(
    heuristic_permutation: list[int], seed=4, sample_size: int = 10
) -> float:
    """Average ``get_mean_path_length`` over ``sample_size`` random scrambles.

    Scramble number i uses seed ``i + seed``. Results are cached on disk by
    the ``shelve_it`` decorator.
    """
    lengths: list[float] = []
    for offset in range(sample_size):
        pl = get_mean_path_length(random_skewb(seed=offset + seed), heuristic_permutation)
        print(f"{pl=}")
        lengths.append(pl)
    return sum(lengths) / len(lengths)
def evaluate_all_1_swaps(hp: list[int]):
    """Evaluate every permutation that differs from ``hp`` by one transposition.

    Each candidate is scored with ``evaluate_permutation(..., sample_size=22)``.
    That function is decorated with ``shelve_it``, so scores are cached on
    disk; nothing is returned. ``hp`` itself is never modified, even if an
    evaluation raises (the old in-place swap left it permuted on error).
    """
    for i in range(len(hp)):
        for j in range(i):
            candidate = list(hp)
            candidate[i], candidate[j] = candidate[j], candidate[i]
            evaluate_permutation(candidate, sample_size=22)
def evaluate_all_1_swaps_except_first(hp: list[int]):
    """Evaluate and print every single transposition of ``hp`` that leaves ``hp[0]`` in place.

    Each candidate is scored with ``evaluate_permutation(..., sample_size=100)``.
    ``hp`` itself is never modified, even if an evaluation raises (the old
    in-place swap left it permuted on error).
    """
    for i in range(len(hp)):
        for j in range(1, i):
            candidate = list(hp)
            candidate[i], candidate[j] = candidate[j], candidate[i]
            evaluation = evaluate_permutation(candidate, sample_size=100)
            # Same text the old f"{hp=} {evaluation=}" produced for the swapped list.
            print(f"hp={candidate!r} evaluation={evaluation!r}")
if __name__ == "__main__":
    # Other entry points used during development:
    # print(bidirectional_search(start, max_steps=20000000, end=solved_skewb))
    # print(breadth_first_search(start=start, is_end=lambda s: s == solved_skewb))

    # Stage orders (indices into get_heuristic_matches) that have been tried.
    # hp = list(range(HURISTIC_PERMUTATION_LENGTH))
    top_down = [0, 1, 2, 3, 8, 9, 10, 11, 12, 4, 5, 6, 7]
    hp = top_down
    # hp = small_corner_start = [0, 8, 11, 1, 2, 3, 9, 10, 12, 4, 5, 6, 7]
    # hp = big_corner_start = [4, 0, 8, 11, 5, 7, 1, 2, 3, 9, 10, 12, 6]

    # Earlier experiments (scores are cached by evaluate_permutation's shelf):
    # print(f"{evaluate_permutation(hp, sample_size=22)=}")
    # evaluate_all_1_swaps(top_down)
    # for hp in (
    #     [10, 1, 2, 3, 8, 9, 0, 11, 12, 4, 5, 6, 7],
    #     [0, 10, 2, 3, 8, 9, 1, 11, 12, 4, 5, 6, 7],
    #     [0, 1, 10, 3, 8, 9, 2, 11, 12, 4, 5, 6, 7],
    #     [0, 1, 2, 10, 8, 9, 3, 11, 12, 4, 5, 6, 7],
    #     [0, 1, 2, 3, 10, 9, 8, 11, 12, 4, 5, 6, 7],
    # ):
    #     print(f"{hp} {evaluate_permutation(hp, sample_size=100)=}")
    # hp = [10, 1, 2, 3, 8, 9, 0, 11, 12, 4, 5, 6, 7]
    # print(f"{hp=} {evaluate_permutation(hp, seed=200, sample_size=200)=}")
    # evaluate_all_1_swaps_except_first(hp)
    # hp = [10, 11, 2, 3, 8, 9, 0, 1, 12, 4, 5, 6, 7]
    # print(f"{hp=} {evaluate_permutation(hp, seed=200, sample_size=200)=}")

    # Stage order currently under test: begins with index 10 (centre mids[2]).
    hp = [10, 0, 1, 2, 3, 8, 9, 11, 12, 4, 5, 6, 7]
    scrambled = Skewb(
        top=(O0, B0, R0, G0),
        bot=(B2, R0, G1, O0),
        mids=(BY, YY, GY, RY, ORB),
    )
    path = get_paths_from_heuristic(scrambled, hp)
    print(path)
    # RBOG