python/skewb_solver.py
2025-04-15 10:46:18 +12:00

622 lines
18 KiB
Python

from math import sqrt
import shelve
from collections import deque
from dataclasses import dataclass, replace
from functools import reduce
import random
from typing import TYPE_CHECKING, Callable, Counter, Dict, List, Literal, Set, Tuple
from functools import lru_cache
import pytest
# NOTE(review): stray top-level print(); it writes a blank line to stdout every
# time this module is imported.
print()
# Twist axes, named by colour. Original note: a twist turns the bottom,
# opposite this colour, clockwise when this colour is facing away.
# Original ordering note: "RBOG" (AXES below is in G, O, B, R order).
Axis = Literal["G", "O", "B", "R"]
# The four axes in declaration order: ("G", "O", "B", "R").
AXES: tuple[Axis, ...] = Axis.__args__
# Number of clockwise twists a corner has undergone since it was last flush
# with the top or bottom.
CornerRotState = Literal[0, 1, 2]
# A move: uppercase = clockwise twist about that axis, lowercase = anticlockwise.
Twist = Literal["G", "O", "B", "R", "g", "o", "b", "r"]
TWISTS: tuple[Twist, ...] = Twist.__args__
# Lookup tables between a clockwise twist (uppercase) and its anticlockwise
# counterpart (lowercase), in both directions.
to_anticlockwise: dict[Axis, Twist] = {"G": "g", "O": "o", "B": "b", "R": "r"}
to_clockwise: dict[Twist, Axis] = {v: k for k, v in to_anticlockwise.items()}
# Maps every twist to its inverse: the same letter in the other case.
to_opposite: dict[Twist, Twist] = {**to_anticlockwise, **to_clockwise}
@dataclass(frozen=True)
class Corner:
    """One corner piece of the skewb.

    `col` is the corner's axis colour and `rot` its rotation state. The repr
    is the two concatenated, e.g. ``Corner("G", 1)`` renders as ``G1``.
    """

    col: Axis
    rot: CornerRotState

    def __repr__(self) -> str:
        return "{}{}".format(self.col, self.rot)
# Orientation label of a middle piece (see Middle.rot).
MidRotState = Literal["Y", "O", "R"]
@dataclass(frozen=True)
class Middle:
    """One middle piece of the skewb.

    `col` is the piece's colour and `rot` its orientation label. The repr is
    the two concatenated, e.g. ``Middle("G", "Y")`` renders as ``GY``.
    """

    col: Literal["G", "O", "B", "R", "Y"]
    rot: MidRotState

    def __repr__(self) -> str:
        return "{}{}".format(self.col, self.rot)
@dataclass(frozen=True)
class Skewb:
    """State of a Rubik's cube variant sold as 'Qiyi Twisty Skewb'.

    Corners are kept in two rings of four (`top` and `bot`); `mids` holds the
    five middle pieces. Instances are immutable; moves return new instances.
    """

    # order: vertically below Axis primary color
    top: Tuple[Corner, Corner, Corner, Corner]
    bot: Tuple[Corner, Corner, Corner, Corner]
    mids: Tuple[Middle, Middle, Middle, Middle, Middle]

    def assert_valid(self):
        """Assert basic invariants of the state.

        Checks performed, in order:
        - each axis colour appears on exactly two corners;
        - the middles are exactly one each of R, B, O, G and Y;
        - middles in slots 0-3 avoid their forbidden rotation (R, O, R, O);
        - slot 4 may only have rotation "Y" when it holds the yellow middle.
        """
        corner_counts = Counter(
            corner.col for ring in (self.top, self.bot) for corner in ring
        )
        assert corner_counts == {col: 2 for col in AXES}
        assert Counter(mid.col for mid in self.mids) == Counter("RBOGY")
        for slot, forbidden_rot in enumerate(("R", "O", "R", "O")):
            assert self.mids[slot].rot != forbidden_rot
        last = self.mids[4]
        assert last.col == "Y" or last.rot != "Y"
# Shorthand for every corner state: colour letter followed by rotation (0-2).
G0 = Corner("G", 0)
G1 = Corner("G", 1)
G2 = Corner("G", 2)
O0 = Corner("O", 0)
O1 = Corner("O", 1)
O2 = Corner("O", 2)
B0 = Corner("B", 0)
B1 = Corner("B", 1)
B2 = Corner("B", 2)
R0 = Corner("R", 0)
R1 = Corner("R", 1)
R2 = Corner("R", 2)
# Shorthand for middle pieces: colour letter followed by orientation letter.
GR = Middle("G", "R")
OR = Middle("O", "R")
BR = Middle("B", "R")
RR = Middle("R", "R")
GO = Middle("G", "O")
OO = Middle("O", "O")
BO = Middle("B", "O")
RO = Middle("R", "O")
GY = Middle("G", "Y")
OY = Middle("O", "Y")
BY = Middle("B", "Y")
RY = Middle("R", "Y")
YY = Middle("Y", "Y")
def rotate_mid_about_W(m: Middle) -> Middle:
    """Re-orient one middle for a whole-skewb turn about white.

    Orientation "Y" is kept, "O" becomes "R", and anything else becomes "O".
    The colour is unchanged.
    """
    if m.rot == "Y":
        new_rot = "Y"
    elif m.rot == "O":
        new_rot = "R"
    else:
        new_rot = "O"
    return Middle(col=m.col, rot=new_rot)
def rotate_everything_about_W(s: Skewb) -> Skewb:
    """Clockwise rotation when looking down upon white.

    Both corner rings shift one slot (the last corner moves to the front).
    The four side middles shift the same way, and every middle, including
    slot 4, is re-oriented with `rotate_mid_about_W`.
    """
    # New slot k takes the middle from old slot source_order[k].
    source_order = (3, 0, 1, 2, 4)
    return Skewb(
        top=(s.top[-1],) + s.top[:-1],
        bot=(s.bot[-1],) + s.bot[:-1],
        mids=tuple(rotate_mid_about_W(s.mids[i]) for i in source_order),
    )
# The solved reference state: all corners at rotation 0, all middles at
# orientation "Y". It is the default goal of bidirectional_search and the
# target the heuristics compare against.
solved_skewb = Skewb(
    top=(G0, O0, B0, R0),
    bot=(G0, O0, B0, R0),
    mids=(GY, OY, BY, RY, YY),
)
@pytest.mark.parametrize("s", [solved_skewb])
def test_rotate_everything_about_W(s: Skewb):
    """Four whole-skewb turns about white are the identity; 1-3 turns are not."""
    orbit = [s]
    while len(orbit) < 5:
        orbit.append(rotate_everything_about_W(orbit[-1]))
    initial, *intermediate, final = orbit
    assert initial == final
    assert all(initial != q for q in intermediate)
def test_axes():
    """AXES keeps the Axis Literal's declaration order, starting with green."""
    first_axis = AXES[0]
    assert first_axis == "G"
# Lookup table from a corner's rotation state before a move to its state after.
type CornerRotPermutation = dict[CornerRotState, CornerRotState]
# Rotation updates applied by clockwise_twist as a corner moves between slots.
# The three slot-to-slot tables are currently identical ({0: 2, 1: 0, 2: 1}),
# i.e. the inverse of ROTATE_CORNER_CLOCKWISE.
BOT_LEFT_TO_TOP: CornerRotPermutation = {0: 2, 1: 0, 2: 1}
TOP_TO_BOT_RIGHT: CornerRotPermutation = {0: 2, 1: 0, 2: 1}
# Used for the corner that stays in its slot but is rotated by the twist.
ROTATE_CORNER_CLOCKWISE: CornerRotPermutation = {0: 1, 1: 2, 2: 0}
BOT_RIGHT_TO_BOT_LEFT: CornerRotPermutation = {0: 2, 1: 0, 2: 1}
@pytest.mark.parametrize(
    "p",
    [BOT_LEFT_TO_TOP, TOP_TO_BOT_RIGHT, ROTATE_CORNER_CLOCKWISE, BOT_RIGHT_TO_BOT_LEFT],
)
def test_rotation_permutations(p: CornerRotPermutation):
    """Each table is a bijection on the corner rotation states {0, 1, 2}."""
    states = {0, 1, 2}
    assert set(p) == states
    assert set(p.values()) == states
# Advances a middle's orientation one step around the cycle Y -> O -> R -> Y;
# clockwise_twist applies it to non-yellow middles that move during a twist.
MID_DIR_INCREMENT: dict[MidRotState, MidRotState] = {"Y": "O", "R": "Y", "O": "R"}
def clockwise_twist(s: Skewb, axis: Axis) -> Skewb:
    """Applies a clockwise twist to the axis.

    Coordinate frame is defined as follows:
    - axis goes through the `top` of a solved cube of a given color
    - the axis-top is on top and faces away from the observer twisting it,
    - clockwise is twisting the bottom when looking up from diagonally-below the skewb

    Implementation: the whole skewb is rotated `rot_before` times about white.
    Then one fixed permutation of pieces is applied; this is the "G" twist,
    since "G" uses no rotation. Finally the skewb is rotated `rot_after` more
    times. Each (before, after) pair sums to 0 or 4 whole-skewb rotations. Four
    rotations cycle every ring back to its start, so the net whole-skewb
    rotation is the identity.

    Returns:
        A new Skewb; `s` itself is not modified (Skewb is frozen).
    """
    # (rotations before, rotations after) of the whole skewb about white.
    rot_before, rot_after = {"G": (0, 0), "O": (3, 1), "B": (2, 2), "R": (1, 3)}[axis]
    for _ in range(rot_before):
        s = rotate_everything_about_W(s)
    # The fixed "G" twist:
    #   corners move bot[0] -> top[2] -> bot[2] -> bot[0], with their rotation
    #   state re-mapped; bot[3] stays put but is rotated;
    #   middles move mids[3] -> mids[2] -> mids[4] -> mids[3].
    # top[0], top[1], top[3], bot[1], mids[0] and mids[1] are untouched.
    s = Skewb(
        top=(
            s.top[0],
            s.top[1],
            Corner((c := s.bot[0]).col, BOT_LEFT_TO_TOP[c.rot]),
            s.top[3],
        ),
        bot=(
            Corner((c := s.bot[2]).col, BOT_RIGHT_TO_BOT_LEFT[c.rot]),
            s.bot[1],
            Corner((c := s.top[2]).col, TOP_TO_BOT_RIGHT[c.rot]),
            Corner((c := s.bot[3]).col, ROTATE_CORNER_CLOCKWISE[c.rot]),
        ),
        mids=(
            s.mids[0],
            s.mids[1],
            # The yellow middle always keeps orientation "Y".
            Middle(
                (m := s.mids[3]).col,
                "Y" if m.col == "Y" else MID_DIR_INCREMENT[m.rot],
            ),
            Middle(
                (m := s.mids[4]).col,
                "Y" if m.col == "Y" else MID_DIR_INCREMENT[m.rot],
            ),
            # "Y" -> "O", anything else -> "R". This agrees with
            # MID_DIR_INCREMENT except for "R", which assert_valid forbids in
            # mids[2] (presumably unreachable here; verify).
            Middle(
                (m := s.mids[2]).col,
                "Y" if m.col == "Y" else ("O" if m.rot == "Y" else "R"),
            ),
        ),
    )
    for _ in range(rot_after):
        s = rotate_everything_about_W(s)
    return s
@pytest.mark.parametrize(
    "axis, want",
    [
        (
            "O",
            Skewb(
                top=(G0, O0, B0, O2),
                bot=(G1, R2, B0, R2),
                mids=(YY, OY, BY, GR, RR),
            ),
        )
    ],
)
def test_clockwise_twist_from_solved(axis: Axis, want: Skewb):
    """One clockwise twist of the solved skewb yields the expected valid state."""
    want.assert_valid()
    got = clockwise_twist(solved_skewb, axis)
    assert got == want
def anticlockwise_twist(s: Skewb, twist: Axis) -> Skewb:
    """Twist anticlockwise about `twist` by applying two clockwise twists.

    A clockwise twist has order 3 (see test_clockwise_twist), so two of them
    equal one anticlockwise twist.
    """
    for _ in range(2):
        s = clockwise_twist(s, twist)
    return s
@pytest.mark.parametrize("start", [solved_skewb])
@pytest.mark.parametrize("axis", AXES)
def test_clockwise_twist(start: Skewb, axis: Axis):
    """Clockwise twists have order 3, and anticlockwise twists undo them."""

    def check_round_trips(state: Skewb) -> None:
        assert anticlockwise_twist(clockwise_twist(state, axis), axis) == state
        assert clockwise_twist(anticlockwise_twist(state, axis), axis) == state

    check_round_trips(start)
    cycle = [start]
    for _ in range(3):
        cycle.append(clockwise_twist(cycle[-1], axis))
    assert len(cycle) == 4
    first, *middle, last = cycle
    assert first == last
    assert all([first != q for q in middle])
    for state in cycle:
        check_round_trips(state)
def apply_twist(start: Skewb, twist: Twist) -> Skewb:
    """Apply one move: uppercase twists clockwise, lowercase anticlockwise."""
    if twist not in AXES:
        return anticlockwise_twist(start, to_clockwise[twist])
    return clockwise_twist(start, twist)
def apply_opposite(s: Skewb, twist: Twist) -> Skewb:
    """Undo `twist` by applying its inverse (the same letter, other case)."""
    inverse = to_opposite[twist]
    return apply_twist(s, inverse)
def apply_twists(start: Skewb, twists: list[Twist]) -> Skewb:
    """Apply `twists` to `start` in order and return the resulting state."""
    state = start
    for twist in twists:
        state = apply_twist(state, twist)
    return state
def instructions(twists: list[Axis]) -> str:
    """Encode a sequence of clockwise twists as a compact instruction string.

    Axis "R" starts out selected. For each twist, one "L" is emitted per step
    needed to advance the selection, cycling in AXES order, to that twist's
    axis. Then "." is emitted for the twist itself. Afterwards every ".." is
    replaced by ":" and every "LLL" by "J".

    Raises:
        ValueError: if a twist is not one of AXES (e.g. a lowercase,
            anticlockwise twist). Previously such input looped forever,
            because the selection can never reach a non-axis value.
    """
    parts: list[str] = []
    axis = "R"
    for twist in twists:
        if twist not in AXES:
            raise ValueError(
                f"instructions() only supports clockwise twists {AXES}, got {twist!r}"
            )
        while axis != twist:
            axis = AXES[(AXES.index(axis) + 1) % len(AXES)]
            parts.append("L")
        parts.append(".")
    return "".join(parts).replace("..", ":").replace("LLL", "J")
def breadth_first_search(
    start: Skewb,
    is_end: Callable[[Skewb], bool],
    max_steps: int = 2000000,
    bidirectional_fallback_threshold: int | None = None,
) -> list[Twist] | None:
    """Breadth-first search from `start` for the nearest state satisfying `is_end`.

    Args:
        start: state to search from; validated with `assert_valid`.
        is_end: predicate identifying goal states.
        max_steps: maximum number of states to expand before giving up.
        bidirectional_fallback_threshold: if set, once more than this many
            states have been expanded, this search is abandoned and the result
            of `bidirectional_search(start, max_steps)` is returned instead.
            That search targets the fully solved skewb, not `is_end`.

    Returns:
        The twists leading from `start` to the first goal state found (`[]` if
        `start` already satisfies `is_end`). Returns None if the reachable
        states or the step budget are exhausted.
    """
    start.assert_valid()
    if is_end(start):
        return []
    queue = deque([start])
    # Maps each discovered state to the twist that produced it (None for start).
    skewb_to_twist: dict[Skewb, Twist | None] = {start: None}

    def get_path(end: Skewb) -> list[Twist]:
        """Walk the recorded twists back from `end` to `start`."""
        out: list[Twist] = []
        s = end
        while twist := skewb_to_twist[s]:
            out.append(twist)
            s = apply_opposite(s, twist)
        out.reverse()
        return out

    # Count expanded states. Previously `max_steps` was never decremented, so
    # the budget was not enforced. The old progress test
    # `max_steps % 1000 == 0` also printed a dot on every step for the default.
    step_count = 0
    while queue and step_count < max_steps:
        if step_count and step_count % 1000 == 0:
            print(".", end="", flush=True)
        if (
            bidirectional_fallback_threshold is not None
            and step_count > bidirectional_fallback_threshold
        ):
            return bidirectional_search(start, max_steps)
        step_count += 1
        parent = queue.popleft()
        for twist in TWISTS:
            child = apply_twist(parent, twist)
            if child in skewb_to_twist:
                continue
            skewb_to_twist[child] = twist
            if is_end(child):
                return get_path(child)
            queue.append(child)
    return None
def test_breadth_first_search():
    """Undoing a single twist is found as the one-move solution."""

    def is_solved(candidate: Skewb) -> bool:
        return candidate == solved_skewb

    for twist in TWISTS:
        scrambled = apply_opposite(solved_skewb, twist)
        assert breadth_first_search(scrambled, is_end=is_solved) == [twist]
def print_path(path: list[Twist], origin: Skewb | None = None):
    """Print the state reached after each twist of `path`.

    Args:
        path: twists to replay in order. Uppercase twists clockwise, as before.
            Lowercase (anticlockwise) twists, as returned by the search
            functions, are now accepted too; previously they raised KeyError.
        origin: state to start from. Defaults to the module-level `start`,
            looked up at call time, which preserves the original behaviour.
    """
    x = start if origin is None else origin
    print(f"S -> {x}")
    for twist in path:
        x = apply_twist(x, twist)
        print(f"{twist} -> {x}")
def bidirectional_search(
    start: Skewb, max_steps: int, end: Skewb = solved_skewb
) -> list[Twist] | None:
    """Find twists taking `start` to `end` by searching from both ends at once.

    Each iteration first expands one state of the backward search, which is
    grown from `end` by applying inverse twists. It then expands one state of
    the forward search, grown from `start`. As soon as a newly discovered state
    is known to both searches, the two half-paths are joined.

    Args:
        start: state to solve from; validated with `assert_valid`.
        max_steps: maximum number of iterations (each expands one state per side).
        end: target state; defaults to `solved_skewb`.

    Returns:
        Twists that take `start` to `end` (`[]` if they are already equal).
        Returns None if either frontier is exhausted or the step budget runs
        out; in that case the number of forward-discovered states is printed.
    """
    start.assert_valid()
    # Without this check, equal endpoints only "met" one twist away, which
    # produced a pointless two-twist path (a twist and then its inverse).
    if start == end:
        return []
    q = deque([start])
    q2 = deque([end])
    # Forward search: state -> twist that produced it from its parent (None at start).
    skewb_to_twist: dict[Skewb, Twist | None] = {start: None}
    # Backward search: state -> twist that takes it one step closer to `end`.
    skewb_to_twist2: dict[Skewb, Twist | None] = {end: None}

    def get_path(meet: Skewb) -> list[Twist]:
        """Join start->meet (forward tree) with meet->end (backward tree)."""
        path: list[Twist] = []
        s = meet
        while twist := skewb_to_twist[s]:
            path.append(twist)
            s = apply_opposite(s, twist)
        path.reverse()
        s = meet
        while twist := skewb_to_twist2[s]:
            path.append(twist)
            s = apply_twist(s, twist)
        return path

    def on_meet(meet: Skewb) -> list[Twist]:
        path = get_path(meet)
        # Sanity check: the joined path really connects the two endpoints.
        assert apply_twists(start, path) == end
        return path

    # Both frontiers must be non-empty: each iteration pops q2 and then q, and
    # popping an empty deque raises IndexError.
    while q and q2 and max_steps > 0:
        max_steps -= 1
        if max_steps % 1000 == 0:
            print(".", end="", flush=True)
        parent2 = q2.popleft()
        for twist in TWISTS:
            child = apply_opposite(parent2, twist)
            if child in skewb_to_twist2:
                continue
            skewb_to_twist2[child] = twist
            if child in skewb_to_twist:
                return on_meet(child)
            q2.append(child)
        parent = q.popleft()
        for twist in TWISTS:
            child = apply_twist(parent, twist)
            if child in skewb_to_twist:
                continue
            skewb_to_twist[child] = twist
            if child in skewb_to_twist2:
                return on_meet(child)
            q.append(child)
    print(f"{len(skewb_to_twist)}")
    return None
def heuristic(got: Skewb) -> float:
    """Total of the `heuristic_list` components for `got`.

    Every component rewards agreement with the solved skewb, so larger totals
    mean more pieces match.
    """
    total = 0
    for component in heuristic_list(got):
        total += component
    return total
def heuristic_list(got: Skewb) -> list[float]:
    """Score components comparing `got` piece by piece with `solved_skewb`.

    The first seven entries reward exact matches of selected pieces, with
    weights 10000 (top[0]), 1000 (bot[3], mids[3]) and 100 (top[1], mids[0],
    bot[0], mids[4]). Next come eight corner entries, top ring then bottom:
    +1 for the right colour and +2 for the right rotation. Last come five
    middle entries: +1 for the right rotation and +4 for the right colour.
    """
    want = solved_skewb
    # (weight, attribute, index) for the heavily weighted exact-match terms.
    priority_pieces = (
        (10000, "top", 0),
        (1000, "bot", 3),
        (1000, "mids", 3),
        (100, "top", 1),
        (100, "mids", 0),
        (100, "bot", 0),
        (100, "mids", 4),
    )
    scores = [
        weight * (getattr(got, part)[index] == getattr(want, part)[index])
        for weight, part, index in priority_pieces
    ]
    scores.extend(
        (c_got.col == c_want.col) + 2 * (c_got.rot == c_want.rot)
        for c_got, c_want in zip(got.top + got.bot, want.top + want.bot)
    )
    scores.extend(
        (m_got.rot == m_want.rot) + 4 * (m_got.col == m_want.col)
        for m_got, m_want in zip(got.mids, want.mids)
    )
    return scores
def get_heuristic_matches(s: Skewb) -> List[int]:
    """Per-piece flags: 1 where a piece exactly matches `solved_skewb`, else 0.

    Index layout:
        [0:4]   top corners
        [4:8]   bottom corners
        [8:13]  middles
    """
    pairs = list(zip(s.top + s.bot, solved_skewb.top + solved_skewb.bot))
    pairs += zip(s.mids, solved_skewb.mids)
    return [
        int(got.col == want.col and got.rot == want.rot) for got, want in pairs
    ]
def element_multiply(a: list[int], b: list[int]) -> list[int]:
    """Return the element-wise product of two equal-length lists."""
    assert len(a) == len(b)
    products = []
    for left, right in zip(a, b):
        products.append(left * right)
    return products
def test_heuristic():
    """A state one twist away from solved scores lower than solved."""
    twisted = apply_twist(solved_skewb, "O")
    assert heuristic(twisted) < heuristic(solved_skewb)
def random_skewb(seed: int = 4, twists: int = 20) -> Skewb:
    """Scramble the solved skewb with `random_skewb_twists(seed, twists)`."""
    scramble = random_skewb_twists(seed, twists)
    return apply_twists(solved_skewb, scramble)
def random_skewb_twists(seed: int = 4, twists: int = 20) -> list[Twist]:
    """Deterministic pseudo-random scramble of `twists` moves from `seed`.

    Consecutive moves never share an axis, because the axis index advances by
    1-3 (mod 4) each time. Each move is clockwise or anticlockwise with equal
    probability.
    """
    rng = random.Random(seed)
    axis_index = rng.randint(0, 3)
    result: list[Twist] = []
    for _ in range(twists):
        axis = AXES[axis_index]
        result.append(to_opposite[axis] if rng.getrandbits(1) else axis)
        axis_index = (axis_index + rng.randint(1, 3)) % len(AXES)
    return result
def double_clockwise_to_anticlockwise(twists: list[Axis]) -> list[Twist]:
    """Replace each adjacent pair of identical clockwise twists with one anticlockwise twist.

    Pairs are taken greedily from the left, so "RRR" becomes ["r", "R"].
    """
    result: list[Twist] = []
    position = 0
    while position < len(twists):
        current = twists[position]
        doubled = position + 1 < len(twists) and twists[position + 1] == current
        if doubled:
            result.append(to_anticlockwise[current])
            position += 2
        else:
            result.append(current)
            position += 1
    return result
@pytest.mark.parametrize(
    "twists, want",
    [
        ("RBOG", "RBOG"),
        ("RRBBOOGG", "rbog"),
        ("RRR", "rR"),
        ("RROR", "rOR"),
        ("RORR", "ROr"),
    ],
)
def test_double_clockwise_to_anticlockwise(twists: list[Axis], want: list[Twist]):
    """Adjacent doubled clockwise twists collapse into one anticlockwise twist."""
    expected = [*want]
    assert double_clockwise_to_anticlockwise(twists) == expected
def test_random_skewb():
    """random_skewb_twists returns exactly the requested number of moves."""
    requested = 50
    assert len(random_skewb_twists(twists=requested)) == requested
def shelve_it(file_name):
    """Decorator factory that memoises a function's results in a `shelve` file.

    The cache key is `str(args) + str(kwargs)`, so arguments need stable,
    distinguishing `str()` forms. Positional and keyword spellings of the
    same call get different keys. Results must be picklable.

    The shelf is opened only around each lookup and store, with a context
    manager, instead of once at decoration time and never closed. The wrapper
    keeps the wrapped function's name and docstring via `functools.wraps`.
    """
    from functools import wraps

    def decorator(func):
        @wraps(func)
        def new_func(*args, **kwargs):
            key = str(args) + str(kwargs)
            with shelve.open(file_name) as cache:
                if key in cache:
                    return cache[key]
            # Compute outside the `with` so the shelf is not held open during
            # a potentially long call.
            value = func(*args, **kwargs)
            with shelve.open(file_name) as cache:
                cache[key] = value
            return value

        return new_func

    return decorator
def get_paths_from_heuristic(
    start: Skewb, heuristic_permutation: list[int]
) -> list[list[Twist]]:
    """Solve `start` in stages, one stage per entry of `heuristic_permutation`.

    Each entry is an index into the 0/1 vector returned by
    `get_heuristic_matches`, so the permutation must have 13 entries (checked
    by the assert in `step_finished`). Each stage adds its index to a
    cumulative mask. It then searches for a state in which every masked piece
    matches the solved skewb.

    Stages normally use breadth_first_search, which falls back to
    bidirectional_search after 20000 expanded states. The stage whose index
    equals len(heuristic_permutation) - 3 calls bidirectional_search directly;
    that search targets the fully solved skewb.

    Returns:
        One list of twists per stage, in order. Applying them all to `start`
        gives the final state.

    Raises:
        ValueError: if any stage's search returns None.
    """
    out: list[list[Twist]] = []
    s = start
    mask = [0 for _ in heuristic_permutation]
    total_path_length = 0  # accumulated but never read
    for heuristic_i in heuristic_permutation:
        mask[heuristic_i] = 1

        def step_finished(candidate: Skewb) -> bool:
            matches = get_heuristic_matches(candidate)
            assert len(matches) == len(mask)
            # Every masked position must match; unmasked ones are free.
            return all(match >= m for match, m in zip(matches, mask))

        print(f"{mask=} {s=}", end=" ")
        # NOTE(review): this compares the heuristic *index value* with a
        # position derived from the permutation length. It means "the
        # third-to-last stage" only when the permutation is the identity; for
        # other permutations it fires on whichever stage adds index len - 3.
        # Confirm which was intended.
        if heuristic_i == len(heuristic_permutation) - 3:
            # print("going bidirectional now.")
            path = bidirectional_search(s, max_steps=200000)
        else:
            path = breadth_first_search(
                s, step_finished, bidirectional_fallback_threshold=20000
            )
        print(f" {path=}")
        if path is None:
            raise ValueError("oh no! solver could not find solution")
        out.append(path)
        s = apply_twists(s, path)
        total_path_length += len(path)
    return out
# def get_total_path_length(start: Skewb, heuristic_permutation: list[int]) -> int:
# return sum(len(p) for p in get_paths_from_heuristic(start, heuristic_permutation))
# close_to_wrongly_solved = Skewb(top=(R0, B0, O0, G0), bot=(B0, O0, G0, R0), mids=(BY, GRB, ORG, RY, YY))
# A fixed hand-made state. It is bound to the module-level `start`, which
# evaluate_permutation solves and print_path replays from.
near_end = Skewb(top=(O0, B0, R0, G2), bot=(B0, R0, G1, O0), mids=(BY, RY, GY, OY, YY))
start = near_end
# start = Skewb(top=(O0, B0, R1, G1), bot=(B0, R2, G2, O0), mids=(BY, RY, GY, OY, YY))
# Length of get_heuristic_matches' vector: 4 top corners + 4 bottom corners
# + 5 middles. (The name is spelled "HURISTIC".)
HURISTIC_PERMUTATION_LENGTH = 4 + 4 + 5
def quadratic_mean(values: list[float]) -> float:
    """Root-mean-square of `values`; raises ZeroDivisionError when empty."""
    sum_of_squares = 0
    for value in values:
        sum_of_squares += value * value
    return sqrt(sum_of_squares / len(values))
def score_fn(values: list[float]) -> float:
    """Negated weighted sum of `values` sorted ascending, with weights 1, 2, 4, ...

    The largest value gets the largest weight, so large values, and above all
    the single largest one, push the score further below zero.
    """
    total = 0
    weight = 1
    for value in sorted(values):
        total += weight * value
        weight *= 2
    return -total
@shelve_it("skewb_solver.evaluate_permutation.shelve.sqlite")
def evaluate_permutation(
    heuristic_permutation: list[int], seed=4, sample_size: int = 10
) -> float:
    """Mean `score_fn` of the per-stage path lengths for `heuristic_permutation`.

    Results are memoised on disk by `shelve_it`, keyed on the call arguments,
    so repeated calls return the stored value without re-solving.

    NOTE(review): `seed` is unused. Every sample also solves the same
    module-level `start` with a deterministic search, so all `sample_size`
    iterations compute the same score. Perhaps each sample was meant to use
    something like random_skewb(seed + i); confirm before relying on
    `sample_size`.

    Raises:
        ZeroDivisionError: if `sample_size` is 0.
        ValueError: propagated from get_paths_from_heuristic when a stage fails.
    """
    scores = []
    for i in range(sample_size):
        paths = get_paths_from_heuristic(start, heuristic_permutation)
        score = score_fn([len(p) for p in paths])
        print(f"{score=}")
        scores.append(score)
    return sum(scores) / len(scores)
def evaluate_all_1_swaps(hp: list[int]):
    """Evaluate every permutation that differs from `hp` by one transposition.

    `hp` is swapped in place before each evaluation and swapped back after it.
    The evaluation results are discarded here.
    """

    def swap(a: int, b: int) -> None:
        hp[a], hp[b] = hp[b], hp[a]

    for hi in range(len(hp)):
        for lo in range(hi):
            swap(hi, lo)
            evaluate_permutation(hp, sample_size=22)
            swap(hi, lo)
def evaluate_all_1_swaps_except_first(hp: list[int]):
    """Evaluate and print every one-transposition variant of `hp` that keeps hp[0].

    Index 0 never takes part in a swap. `hp` is swapped in place before each
    evaluation and restored afterwards.
    """

    def swap(a: int, b: int) -> None:
        hp[a], hp[b] = hp[b], hp[a]

    for i in range(len(hp)):
        for j in range(1, i):
            swap(i, j)
            evaluation = evaluate_permutation(hp, sample_size=100)
            print(f"{hp=} {evaluation=}")
            swap(i, j)
if __name__ == "__main__":
    # Candidate stage orders: indices into get_heuristic_matches' vector
    # ([0:4] top corners, [4:8] bottom corners, [8:13] middles). The second
    # assignment wins, so `hp` is `top_down`: top corners, then middles, then
    # bottom corners.
    hp = top_bot_mids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
    hp = top_down = [0, 1, 2, 3, 8, 9, 10, 11, 12, 4, 5, 6, 7]
    # evaluate_all_1_swaps(top_down)
    evaluate_all_1_swaps_except_first(hp)
    # Rebinds the module-level `start` after the evaluation above. Only the
    # commented-out calls below would use it.
    start = Skewb((G0, O0, B0, R0), (G0, R1, B2, O2), (BY, OY, RO, GR, YY))
    # print(bidirectional_search(start, max_steps=1000000))
    # print(get_paths_from_heuristic(start, hp))