python/skewb_solver.py
DomNomNomVR f707876581 skerwb
2025-12-20 01:47:30 +13:00

737 lines
22 KiB
Python

from multiprocessing import Pool, TimeoutError
from math import sqrt
import shelve
from collections import deque
from dataclasses import dataclass, replace
from functools import reduce
import random
from typing import (
TYPE_CHECKING,
Callable,
Counter,
Dict,
Iterable,
Iterator,
List,
Literal,
Set,
Tuple,
)
from functools import lru_cache
import pytest
# An axis is named by the colour whose `top` corner it passes through when the
# puzzle is solved. Twisting an axis turns the bottom: clockwise when that
# colour faces away. (Original note: RBOG.)
Axis = Literal["G", "O", "B", "R"]
AXES: tuple[Axis, ...] = Axis.__args__
# How many clockwise twists along an axis a corner has been twisted after
# being flush with the top or bottom.
CornerRotState = Literal[0, 1, 2]
# Upper case: clockwise twist about that axis; lower case: anticlockwise.
Twist = Literal["G", "O", "B", "R", "g", "o", "b", "r"]
TWISTS: tuple[Twist, ...] = Twist.__args__
to_anticlockwise: dict[Axis, Twist] = {axis: axis.lower() for axis in AXES}
to_clockwise: dict[Twist, Axis] = {
    lower: upper for upper, lower in to_anticlockwise.items()
}
to_opposite: dict[Twist, Twist] = to_anticlockwise | to_clockwise
@dataclass(frozen=True)
class Corner:
    """A corner piece: its colour axis and its rotation state (immutable, hashable)."""

    col: Axis
    rot: CornerRotState

    def __repr__(self) -> str:
        # Compact form such as "G0": colour letter then rotation state.
        return "{}{}".format(self.col, self.rot)
# Which direction a middle piece is pointing in (target middle solved colour).
MidRotState = Literal["Y", "O", "G"]


@dataclass(frozen=True)
class Middle:
    """A middle piece: its colour and the direction it points in (immutable, hashable)."""

    col: Literal["G", "O", "B", "R", "Y"]
    rot: MidRotState

    def __repr__(self) -> str:
        # Compact form such as "GY": colour letter then direction letter.
        return "{}{}".format(self.col, self.rot)
@dataclass(frozen=True)
class Skewb:
    """Represents a Rubics cube variant sold as 'Qiyi Twisty Skewb'.

    The state is three immutable tuples of pieces, so a `Skewb` is hashable and
    can be used as a dict key (the searches in this module rely on that).
    """

    # Position order within each tuple: screw down-right from start.
    top: Tuple[Corner, Corner, Corner, Corner]
    bot: Tuple[Corner, Corner, Corner, Corner]
    mids: Tuple[Middle, Middle, Middle, Middle, Middle]

    def assert_valid(self):
        """Assert basic consistency of the state.

        Checks that every axis colour appears on exactly two corners, that the
        five middles are one each of R, B, O, G and Y, and some direction
        constraints on the middles.
        """
        corner_colours = Counter(corner.col for corner in (*self.top, *self.bot))
        assert corner_colours == {col: 2 for col in AXES}
        assert Counter(mid.col for mid in self.mids) == Counter("RBOGY")
        # Forbidden directions for the first four middle positions.
        # NOTE(review): MidRotState is Literal["Y", "O", "G"], so the "R" checks
        # can never fail -- confirm which direction was meant to be forbidden.
        for mid, forbidden_rot in zip(self.mids, ("R", "O", "R", "O")):
            assert mid.rot != forbidden_rot
        # In the last position, only the yellow middle may point at "Y".
        last = self.mids[4]
        assert not (last.col != "Y" and last.rot == "Y")
# Named corner states: colour letter followed by rotation state.
G0, G1, G2 = Corner("G", 0), Corner("G", 1), Corner("G", 2)
O0, O1, O2 = Corner("O", 0), Corner("O", 1), Corner("O", 2)
B0, B1, B2 = Corner("B", 0), Corner("B", 1), Corner("B", 2)
R0, R1, R2 = Corner("R", 0), Corner("R", 1), Corner("R", 2)
# Named middle states: colour letter followed by the direction it points in.
GG, OG, BG, RG = (
    Middle("G", "G"),
    Middle("O", "G"),
    Middle("B", "G"),
    Middle("R", "G"),
)
GO, OO, BO, RO = (
    Middle("G", "O"),
    Middle("O", "O"),
    Middle("B", "O"),
    Middle("R", "O"),
)
GY, OY, BY, RY = (
    Middle("G", "Y"),
    Middle("O", "Y"),
    Middle("B", "Y"),
    Middle("R", "Y"),
)
YY = Middle("Y", "Y")
def rotate_mid_about_W(m: Middle) -> Middle:
    """Return `m` re-pointed for one whole-puzzle turn about W.

    A middle pointing "Y" keeps pointing "Y"; "O" becomes "G"; any other
    direction (i.e. "G") becomes "O". The colour is unchanged.
    """
    new_rot = {"Y": "Y", "O": "G"}.get(m.rot, "O")
    return Middle(col=m.col, rot=new_rot)
def rotate_everything_about_W(s: Skewb) -> Skewb:
    """Clockwise rotation when looking down upon white.

    Corner position i receives the corner from position i - 1 (wrapping), the
    first four middles shift the same way, the fifth middle stays in place,
    and every middle is re-pointed with `rotate_mid_about_W`.
    """

    def shift_right(pieces):
        return (pieces[-1],) + pieces[:-1]

    side_mids = tuple(rotate_mid_about_W(s.mids[i]) for i in (3, 0, 1, 2))
    return Skewb(
        top=shift_right(s.top),
        bot=shift_right(s.bot),
        mids=side_mids + (rotate_mid_about_W(s.mids[4]),),
    )
# After rotate_everything_about_W, positions 0 1 2 3 hold the pieces that were
# previously at positions 3 0 1 2.
# Reference solved state: every corner at rotation 0, every middle pointing "Y".
solved_skewb = Skewb(
    mids=(GY, OY, BY, RY, YY),
    top=(G0, O0, B0, R0),
    bot=(G0, O0, B0, R0),
)
@pytest.mark.parametrize("s", [solved_skewb])
def test_rotate_everything_about_W(s: Skewb):
    """Four whole-puzzle turns return to the start; one to three turns do not."""
    states = [s]
    while len(states) < 5:
        states.append(rotate_everything_about_W(states[-1]))
    first, *intermediate, last = states
    assert first == last
    assert all(first != q for q in intermediate)
def test_axes():
    """`AXES` must list "G" first."""
    first_axis, *_ = AXES
    assert first_axis == "G"
type CornerRotPermutation = dict[CornerRotState, CornerRotState]
BOT_LEFT_TO_TOP: CornerRotPermutation = {0: 2, 1: 0, 2: 1}
TOP_TO_BOT_RIGHT: CornerRotPermutation = {0: 2, 1: 0, 2: 1}
ROTATE_CORNER_CLOCKWISE: CornerRotPermutation = {0: 1, 1: 2, 2: 0}
BOT_RIGHT_TO_BOT_LEFT: CornerRotPermutation = {0: 2, 1: 0, 2: 1}
@pytest.mark.parametrize(
    "p",
    [BOT_LEFT_TO_TOP, TOP_TO_BOT_RIGHT, ROTATE_CORNER_CLOCKWISE, BOT_RIGHT_TO_BOT_LEFT],
)
def test_rotation_permutations(p: CornerRotPermutation):
    """Each corner-rotation table must be a bijection on {0, 1, 2}."""
    rotation_states = {0, 1, 2}
    assert set(p) == rotation_states
    assert set(p.values()) == rotation_states
# Advances a middle's direction one step around the cycle Y -> O -> G -> Y.
MID_DIR_INCREMENT: dict[MidRotState, MidRotState] = dict(Y="O", G="Y", O="G")
def clockwise_twist(s: Skewb, axis: Axis) -> Skewb:
    """Applies a clockwise twist to the axis.

    Coordinate frame is defined as follows:
    - axis goes through the `top` of a solved cube of a given color
    - the axis-top is on top and faces away from the observer twisting it,
    - clockwise is twisting the bottom when looking up from diagonally-below the skewb

    The piece permutation below is written once, for the "G" axis (which needs
    no re-orientation). For the other axes the whole puzzle is first turned
    `rot_before` times with `rotate_everything_about_W`, permuted, and then
    turned `rot_after` more times. `rot_before + rot_after` is 0 or 4 for every
    axis, and four whole-puzzle turns return to the starting frame (the tuple
    shift has period 4 and `rotate_mid_about_W` has period 2).

    Args:
        s: state to twist; it is not modified (`Skewb` is frozen).
        axis: the axis to twist.

    Returns:
        A new `Skewb` after one clockwise twist about `axis`.
    """
    rot_before, rot_after = {"G": (0, 0), "O": (3, 1), "B": (2, 2), "R": (1, 3)}[axis]
    # Re-orient so that `axis` sits where the "G" permutation below expects it.
    for _ in range(rot_before):
        s = rotate_everything_about_W(s)
    s = Skewb(
        # Corners: 3-cycle top[2] -> bot[2] -> bot[0] -> top[2], each moved corner
        # re-oriented by its transition table; bot[3] stays in place but rotates;
        # top[0], top[1], top[3] and bot[1] are untouched.
        top=(
            s.top[0],
            s.top[1],
            Corner((c := s.bot[0]).col, BOT_LEFT_TO_TOP[c.rot]),
            s.top[3],
        ),
        bot=(
            Corner((c := s.bot[2]).col, BOT_RIGHT_TO_BOT_LEFT[c.rot]),
            s.bot[1],
            Corner((c := s.top[2]).col, TOP_TO_BOT_RIGHT[c.rot]),
            Corner((c := s.bot[3]).col, ROTATE_CORNER_CLOCKWISE[c.rot]),
        ),
        # Middles: 3-cycle mids[2] -> mids[4] -> mids[3] -> mids[2]; mids[0] and
        # mids[1] are untouched. The yellow middle always keeps direction "Y".
        mids=(
            s.mids[0],
            s.mids[1],
            Middle(
                (m := s.mids[3]).col,
                "Y" if m.col == "Y" else MID_DIR_INCREMENT[m.rot],
            ),
            Middle(
                (m := s.mids[4]).col,
                "Y" if m.col == "Y" else MID_DIR_INCREMENT[m.rot],
            ),
            # A non-yellow middle arriving in position 4 never points "Y"
            # (matching the last check in Skewb.assert_valid): "Y" -> "O",
            # anything else -> "G".
            Middle(
                (m := s.mids[2]).col,
                "Y" if m.col == "Y" else ("O" if m.rot == "Y" else "G"),
            ),
        ),
    )
    # Turn the rest of the way so the net whole-puzzle rotation is a full cycle.
    for _ in range(rot_after):
        s = rotate_everything_about_W(s)
    return s
@pytest.mark.parametrize(
    "axis, want",
    [
        (
            "O",
            Skewb(
                top=(G0, O0, B0, O2), bot=(G1, R2, B0, R2), mids=(YY, OY, BY, GG, RG)
            ),
        )
    ],
)
def test_clockwise_twist_from_solved(axis: Axis, want: Skewb):
    """One clockwise twist from solved produces a known expected state."""
    want.assert_valid()
    got = clockwise_twist(solved_skewb, axis)
    assert got == want
def anticlockwise_twist(s: Skewb, twist: Axis) -> Skewb:
    """Twist `twist` anticlockwise, implemented as two clockwise twists."""
    for _ in range(2):
        s = clockwise_twist(s, twist)
    return s
@pytest.mark.parametrize("start", [solved_skewb])
@pytest.mark.parametrize("axis", AXES)
def test_clockwise_twist(start: Skewb, axis: Axis):
    """Clockwise and anticlockwise twists undo each other, and three clockwise
    twists return to the start with both intermediate states different."""

    def check_inverse(state: Skewb) -> None:
        assert anticlockwise_twist(clockwise_twist(state, axis), axis) == state
        assert clockwise_twist(anticlockwise_twist(state, axis), axis) == state

    check_inverse(start)
    orbit = [start]
    for _ in range(3):
        orbit.append(clockwise_twist(orbit[-1], axis))
    assert len(orbit) == 4
    assert orbit[0] == orbit[-1]
    assert all(orbit[0] != q for q in orbit[1:-1])
    for state in orbit:
        check_inverse(state)
def apply_twist(start: Skewb, twist: Twist) -> Skewb:
    """Apply one twist: upper case is clockwise, lower case anticlockwise."""
    if twist not in AXES:
        return anticlockwise_twist(start, to_clockwise[twist])
    return clockwise_twist(start, twist)
def apply_opposite(s: Skewb, twist: Twist) -> Skewb:
    """Apply the inverse of `twist` (same axis, opposite direction)."""
    inverse = to_opposite[twist]
    return apply_twist(s, inverse)
def apply_twists(start: Skewb, twists: list[Twist]) -> Skewb:
    """Apply `twists` to `start` in order and return the resulting state."""
    state = start
    for twist in twists:
        state = apply_twist(state, twist)
    return state
def instructions(twists: Iterable[Twist]) -> str:
    """Translate a twist sequence into a compact hand-movement string.

    The hand starts on axis "R". Each "L" advances to the next axis in `AXES`
    order (wrapping around), and each "." is one clockwise twist of the
    current axis. Afterwards ".." is collapsed to ":" and "LLL" to "J".

    Anticlockwise (lower-case) twists are emitted as two clockwise twists,
    matching how `anticlockwise_twist` is implemented.

    Raises:
        ValueError: if an element is not one of `TWISTS`.
    """
    out = ""
    axis = "R"
    for twist in twists:
        if twist in AXES:
            target, presses = twist, "."
        elif twist in to_clockwise:
            # An anticlockwise twist is two clockwise twists of the same axis.
            target, presses = to_clockwise[twist], ".."
        else:
            # Anything else never equals an axis, so the loop below would spin forever.
            raise ValueError(f"unknown twist {twist!r}")
        while axis != target:
            axis = AXES[(AXES.index(axis) + 1) % len(AXES)]
            out += "L"
        out += presses
    out = out.replace("..", ":")
    out = out.replace("LLL", "J")
    return out
def breadth_first_search(
    start: Skewb,
    is_end: Callable[[Skewb], bool],
    max_steps: int = 2000000,
    bidirectional_fallback_threshold: int | None = None,
) -> list[Twist] | None:
    """Breadth-first search from `start` for the nearest state satisfying `is_end`.

    Args:
        start: state to search from; validated with `assert_valid`.
        is_end: goal predicate.
        max_steps: maximum number of states expanded (popped from the frontier)
            before giving up.
        bidirectional_fallback_threshold: if set, once more than this many
            states have been expanded (checked every 1000 expansions), abandon
            the BFS and return `bidirectional_search(start, max_steps)`. That
            searches towards `bidirectional_search`'s default end state, not
            towards `is_end`.

    Returns:
        The shortest twist list from `start` to a goal state, `[]` if `start`
        already satisfies `is_end`, or None if the reachable states or the step
        budget are exhausted.
    """
    start.assert_valid()
    if is_end(start):
        return []
    q = deque([start])
    # The twist that first reached each discovered state (None for `start`).
    skewb_to_twist: dict[Skewb, Twist | None] = {start: None}

    def get_path(end: Skewb) -> list[Twist]:
        """Walk back from `end` to `start` by undoing the recorded twists."""
        out = []
        s = end
        while twist := skewb_to_twist[s]:
            out.append(twist)
            s = apply_opposite(s, twist)
        out.reverse()
        return out

    step_count = 0
    # `step_count` is what consumes the `max_steps` budget.
    while q and step_count < max_steps:
        if (
            bidirectional_fallback_threshold is not None
            and step_count
            and step_count % 1000 == 0
            and step_count > bidirectional_fallback_threshold
        ):
            print("bidirectional", end=" ")
            return bidirectional_search(start, max_steps)
        step_count += 1
        parent = q.popleft()
        for twist in TWISTS:
            child = apply_twist(parent, twist)
            if child in skewb_to_twist:
                continue
            skewb_to_twist[child] = twist
            if is_end(child):
                return get_path(child)
            q.append(child)
    return None
def test_breadth_first_search():
    """A state one twist from solved is solved by exactly that twist."""

    def is_solved(candidate: Skewb) -> bool:
        return candidate == solved_skewb

    for twist in TWISTS:
        scrambled = apply_opposite(solved_skewb, twist)
        assert breadth_first_search(scrambled, is_end=is_solved) == [twist]
def bidirectional_search(
    start: Skewb, max_steps: int, end: Skewb = solved_skewb
) -> list[Twist] | None:
    """Return the first path produced by `bidirectional_search_all`, or None if it yields none."""
    return next(bidirectional_search_all(start, max_steps, end), None)
def bidirectional_search_all(
    start: Skewb, max_steps: int, end: Skewb = solved_skewb
) -> Iterator[list[Twist]]:
    """Yield distinct twist sequences that take `start` to `end`.

    Two breadth-first searches run in lock-step: a forward one from `start`
    (expanding with `apply_twist`) and a backward one from `end` (expanding
    with `apply_opposite`). Whenever a newly discovered state is already known
    to the other search, the two half-paths are joined. A joined path is
    yielded only if it has not been yielded before and is no longer than the
    most recently yielded path.

    Args:
        start: initial state; validated with `assert_valid` unless equal to `end`.
        max_steps: number of lock-step iterations (each expands at most one
            state per direction) before giving up.
        end: target state.

    Yields:
        Twist lists `path` with `apply_twists(start, path) == end`.

    Side effects:
        Prints a "." every 1000 iterations and, if the generator runs to
        completion, the number of states reached from `start`.
    """
    if start == end:
        yield []
        return
    start.assert_valid()
    q = deque([start])
    q2 = deque([end])
    # The twist that first reached each state, per search direction.
    skewb_to_twist: dict[Skewb, Twist | None] = {start: None}
    skewb_to_twist2: dict[Skewb, Twist | None] = {end: None}
    last_path: list[Twist] | None = None
    seen_paths: set[tuple[Twist, ...]] = set()

    def get_path(meet: Skewb) -> list[Twist]:
        """Join start -> meet (forward map) with meet -> end (backward map)."""
        path = []
        s = meet
        while twist := skewb_to_twist[s]:
            path.append(twist)
            s = apply_opposite(s, twist)
        path.reverse()
        s = meet
        while twist := skewb_to_twist2[s]:
            path.append(twist)
            s = apply_twist(s, twist)
        return path

    def accept(meet: Skewb) -> list[Twist] | None:
        """Return the joined path through `meet` if it should be yielded, else None."""
        nonlocal last_path
        path = get_path(meet)
        assert apply_twists(start, path) == end
        key = tuple(path)
        if key in seen_paths:
            return None
        if last_path is not None and len(path) > len(last_path):
            return None
        seen_paths.add(key)
        last_path = path
        return path

    def expand(frontier, own, other, step) -> Iterator[list[Twist]]:
        """Expand one state of `frontier`, yielding any accepted meeting paths."""
        parent = frontier.popleft()
        for twist in TWISTS:
            child = step(parent, twist)
            if child in own:
                continue
            own[child] = twist
            if child in other and (path := accept(child)) is not None:
                yield path
            frontier.append(child)

    # Each frontier is guarded separately: the backward one may empty first
    # (e.g. `start` cannot reach `end`), and popping it must not raise.
    while (q or q2) and max_steps > 0:
        max_steps -= 1
        if max_steps % 1000 == 0:
            print(".", end="", flush=True)
        if q2:
            yield from expand(q2, skewb_to_twist2, skewb_to_twist, apply_opposite)
        if q:
            yield from expand(q, skewb_to_twist, skewb_to_twist2, apply_twist)
    print(f"{len(skewb_to_twist)}")
def heuristic(got: Skewb) -> float:
    """Total closeness score of `got` to solved: the sum of `heuristic_list`."""
    terms = heuristic_list(got)
    return sum(terms)
def heuristic_list(got: Skewb) -> list[float]:
    """Per-term scores comparing `got` with `solved_skewb` (higher is closer).

    The first seven terms are weighted "piece exactly solved" flags for
    specific positions. Then there is one term per corner (1 for a colour
    match, plus 2 for a rotation match) and one term per middle (1 for a
    direction match, plus 4 for a colour match).
    """
    want = solved_skewb
    priority_terms = [
        (10000, got.top[0], want.top[0]),
        (1000, got.bot[3], want.bot[3]),
        (1000, got.mids[3], want.mids[3]),
        (100, got.top[1], want.top[1]),
        (100, got.mids[0], want.mids[0]),
        (100, got.bot[0], want.bot[0]),
        (100, got.mids[4], want.mids[4]),
    ]
    scores = [weight * (piece == target) for weight, piece, target in priority_terms]
    scores.extend(
        int(c_got.col == c_want.col) + 2 * int(c_got.rot == c_want.rot)
        for c_got, c_want in zip(got.top + got.bot, want.top + want.bot)
    )
    scores.extend(
        int(m_got.rot == m_want.rot) + 4 * int(m_got.col == m_want.col)
        for m_got, m_want in zip(got.mids, want.mids)
    )
    return scores
def get_heuristic_matches(s: Skewb) -> List[int]:
    """One 0/1 flag per piece: 1 when that piece equals its solved state.

    format:
    [0:4] top
    [4:8] bot
    [8:13] mids
    """

    def solved_flags(pieces, targets) -> List[int]:
        return [
            int(piece.col == target.col and piece.rot == target.rot)
            for piece, target in zip(pieces, targets)
        ]

    solved = solved_skewb
    corner_flags = solved_flags(s.top + s.bot, solved.top + solved.bot)
    return corner_flags + solved_flags(s.mids, solved.mids)
def element_multiply(a: list[int], b: list[int]) -> list[int]:
    """Return the element-wise product of two equal-length lists."""
    assert len(a) == len(b)
    return [left * right for left, right in zip(a, b)]
def test_heuristic():
    """A single twist away from solved must score lower than solved."""
    scrambled = apply_twist(solved_skewb, "O")
    assert heuristic(scrambled) < heuristic(solved_skewb)
def random_skewb(seed: int = 4, twists: int = 20) -> Skewb:
    """Scramble `solved_skewb` with the twist sequence from `random_skewb_twists`."""
    scramble = random_skewb_twists(seed, twists)
    return apply_twists(solved_skewb, scramble)
def random_skewb_twists(seed: int = 4, twists: int = 20) -> list[Twist]:
    """Deterministic pseudo-random scramble of `twists` moves.

    Consecutive twists always use different axes. Each twist is clockwise or
    anticlockwise with equal probability.
    """
    rng = random.Random(seed)
    axis_index = rng.randint(0, 3)
    sequence: list[Twist] = []
    for _ in range(twists):
        axis = AXES[axis_index]
        sequence.append(to_opposite[axis] if rng.getrandbits(1) else axis)
        # Step 1-3 places around the axes so the next axis always differs.
        axis_index = (axis_index + rng.randint(1, 3)) % len(AXES)
    return sequence
def double_clockwise_to_anticlockwise(twists: list[Axis]) -> list[Twist]:
    """Replace each adjacent pair of equal clockwise twists with one anticlockwise twist.

    Pairs are taken greedily from the left, so "RRR" becomes "rR".
    """
    out: list[Twist] = []
    pending = None  # an unpaired twist waiting to see whether its neighbour matches
    for twist in twists:
        if pending is None:
            pending = twist
        elif pending == twist:
            out.append(to_anticlockwise[twist])
            pending = None
        else:
            out.append(pending)
            pending = twist
    if pending is not None:
        out.append(pending)
    return out
@pytest.mark.parametrize(
    "twists, want",
    [
        ("RBOG", "RBOG"),
        ("RRBBOOGG", "rbog"),
        ("RRR", "rR"),
        ("RROR", "rOR"),
        ("RORR", "ROr"),
    ],
)
def test_double_clockwise_to_anticlockwise(twists: list[Axis], want: list[Twist]):
    """Equal adjacent clockwise twists collapse greedily, left to right."""
    got = double_clockwise_to_anticlockwise(twists)
    assert got == list(want)
def test_random_skewb():
    """The scramble generator returns exactly as many twists as requested."""
    requested = 50
    assert len(random_skewb_twists(twists=requested)) == requested
def shelve_it(file_name):
    """Decorator factory: memoise a function's results in a `shelve` file.

    The shelf at `file_name` is opened once, when the decorator is created,
    and is kept open. The cache key is `str(args) + str(kwargs)`, so calls
    with the same values passed differently (positional vs keyword) are cached
    separately. Results must be picklable.
    """
    from functools import wraps

    d = shelve.open(file_name)

    def decorator(func):
        @wraps(func)  # keep the wrapped function's name and docstring
        def new_func(*args, **kwargs):
            key = str(args) + str(kwargs)
            if key not in d:
                d[key] = func(*args, **kwargs)
                # This shelf is never closed, so flush each new entry explicitly.
                d.sync()
            return d[key]

        return new_func

    return decorator
def get_paths_from_heuristic(
    start: Skewb, heuristic_permutation: list[int]
) -> list[list[Twist]]:
    """Solve `start` in stages, one stage per entry of `heuristic_permutation`.

    Each entry is an index into the flags returned by `get_heuristic_matches`
    and is added to a cumulative mask. Each stage runs a breadth-first search
    to a state where every masked piece is solved, so pieces fixed in earlier
    stages must stay fixed. At the stage where the mask sum equals
    `len(heuristic_permutation) - 3`, the rest is solved in one go with
    `bidirectional_search` (to its default end state); any later stages then
    find their goal already met.

    Returns:
        One twist list per stage, in order.

    Raises:
        ValueError: if a stage finds no path.
    """
    out: list[list[Twist]] = []
    s = start
    mask = [0 for _ in heuristic_permutation]

    def step_finished(candidate: Skewb) -> bool:
        # Reads `mask` at call time, so it always uses the current stage's mask.
        matches = get_heuristic_matches(candidate)
        assert len(matches) == len(mask)
        return all(match >= m for match, m in zip(matches, mask))

    for heuristic_i in heuristic_permutation:
        mask[heuristic_i] = 1
        if sum(mask) == len(heuristic_permutation) - 3:
            path = bidirectional_search(s, max_steps=200000)
        else:
            path = breadth_first_search(
                s, step_finished, bidirectional_fallback_threshold=200000
            )
        if path is None:
            raise ValueError("oh no! solver could not find solution")
        out.append(path)
        s = apply_twists(s, path)
    return out
# def get_total_path_length(start: Skewb, heuristic_permutation: list[int]) -> int:
#     return sum(len(p) for p in get_paths_from_heuristic(start, heuristic_permutation))
# close_to_wrongly_solved = Skewb(top=(R0, B0, O0, G0), bot=(B0, O0, G0, R0), mids=(BY, GRB, ORG, RY, YY))
# start = Skewb(top=(O0, B0, R1, G1), bot=(B0, R2, G2, O0), mids=(BY, RY, GY, OY, YY))

# Number of entries in a heuristic permutation: one per flag produced by
# get_heuristic_matches (4 top corners + 4 bottom corners + 5 middles).
HURISTIC_PERMUTATION_LENGTH = 13
def quadratic_mean(values: list[float]) -> float:
    """Root mean square of `values`."""
    mean_square = sum(v * v for v in values) / len(values)
    return sqrt(mean_square)
def score_fn(values: list[float]) -> float:
    """Negated weighted sum in which the k-th smallest value is weighted by 2**k.

    The largest values dominate. For non-negative values (e.g. path lengths),
    a higher score means smaller values.
    """
    ascending = sorted(values)
    weighted = (2**rank * value for rank, value in enumerate(ascending))
    return -sum(weighted)
def parallel_task(seed: int, heuristic_permutation: list[int]):
    """Worker: solve the scramble for `seed` and score the stage path lengths."""
    scramble = random_skewb(seed)
    stages = get_paths_from_heuristic(scramble, heuristic_permutation)
    return score_fn([len(stage) for stage in stages])
@shelve_it("skewb_solver.evaluate_permutation.shelve.sqlite")
def evaluate_permutation(
    heuristic_permutation: list[int], seed=4, sample_size: int = 10
) -> float:
    """Mean `parallel_task` score of a permutation over `sample_size` scrambles.

    Scramble seeds come from `random.Random(seed)`. Tasks run on a pool of 32
    worker processes, and results are cached on disk by `shelve_it`.
    """
    seed_rng = random.Random(seed)
    task_args = [
        (seed_rng.randint(0, 2**63), heuristic_permutation) for _ in range(sample_size)
    ]
    scores = []
    with Pool(processes=32) as pool:
        results = pool.starmap(parallel_task, task_args)
        for score in results:
            print(f"{heuristic_permutation=} {score=}", flush=True)
            scores.append(score)
    return sum(scores) / len(scores)
def evaluate_all_1_swaps(hp: list[int]):
    """Evaluate every permutation that is one transposition away from `hp`.

    `hp` is mutated in place for each evaluation and swapped back afterwards.
    """

    def swap(a: int, b: int) -> None:
        hp[a], hp[b] = hp[b], hp[a]

    for i in range(len(hp)):
        for j in range(i):
            swap(i, j)
            evaluate_permutation(hp, sample_size=22)
            swap(i, j)
def evaluate_all_1_swaps_except_first(hp: list[int], sample_size: int):
    """Like `evaluate_all_1_swaps`, but never moves `hp[0]`, and prints each result.

    `hp` is mutated in place for each evaluation and swapped back afterwards.
    """

    def swap(a: int, b: int) -> None:
        hp[a], hp[b] = hp[b], hp[a]

    for i in range(len(hp)):
        for j in range(1, i):
            swap(i, j)
            evaluation = evaluate_permutation(hp, sample_size=sample_size)
            print(f"{hp=} {evaluation=}")
            swap(i, j)
def print_path_detailed(s: Skewb, paths: list[list[Twist]]):
    """Print each stage's starting state next to that stage's twists."""
    state = s
    for segment in paths:
        print(state, segment)
        state = apply_twists(state, segment)
def reverse_path(path: List[Twist]) -> List[Twist]:
    """Return the twist sequence that undoes `path` (reversed order, each twist inverted)."""
    return list(map(to_opposite.__getitem__, reversed(path)))
if __name__ == "__main__":
    # Previously explored heuristic permutations (see get_paths_from_heuristic):
    # hp = top_bot_mids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
    # hp = top_down_modified = [0, 1, 2, 3, 8, 9, 10, 11, 12, 4, 5, 6, 7]
    # hp = top_down = [0, 1, 2, 3, 8, 9, 10, 11, 12, 4, 5, 6, 7]
    # hp = local_minimum = [0, 1, 6, 3, 8, 9, 10, 11, 12, 4, 5, 2, 7]
    # hp = second_1_swap_min = [0, 3, 6, 1, 8, 9, 10, 11, 12, 4, 5, 2, 7]
    # good_hps = [
    #     [0, 1, 12, 3, 8, 9, 10, 11, 6, 4, 5, 2, 7],
    #     [0, 1, 3, 6, 8, 9, 10, 11, 12, 4, 5, 2, 7],
    #     [0, 1, 6, 3, 11, 9, 10, 8, 12, 4, 5, 2, 7],
    #     [0, 1, 6, 8, 3, 9, 10, 11, 12, 4, 5, 2, 7],
    #     [0, 1, 6, 9, 8, 3, 10, 11, 12, 4, 5, 2, 7],
    #     [0, 1, 8, 3, 6, 9, 10, 11, 12, 4, 5, 2, 7],
    #     [0, 3, 6, 1, 8, 9, 10, 11, 12, 4, 5, 2, 7],
    #     [0, 5, 6, 3, 8, 9, 10, 11, 12, 4, 1, 2, 7],
    #     [0, 8, 6, 3, 1, 9, 10, 11, 12, 4, 5, 2, 7],
    #     [0, 10, 6, 3, 8, 9, 1, 11, 12, 4, 5, 2, 7],
    # ]
    # hp = second_1_swap_min_200 = [0, 3, 6, 1, 8, 9, 10, 11, 12, 4, 5, 2, 7]
    # start = Skewb(
    #     (G0, O0, G2, R0),
    #     (B2, R1, B1, O0),
    #     (GY, OY, RO, YY, BO),
    # )
    # print_path_detailed(start, get_paths_from_heuristic(start, hp))
    start = Skewb(  # midswap
        top=(G0, R0, B0, O0),
        mids=(YY, RY, BY, OY, GO),
        bot=(G0, R0, B0, O0),
    )
    # start = Skewb(
    #     top=(G0, O0, B1, R0),
    #     mids=(YY, OY, GY, RY, BO),
    #     bot=(B2, O2, G0, R0),
    # )

    # The search is expensive: run it once and reuse the result for both prints.
    first_path = bidirectional_search(start, 100000, solved_skewb)
    print(first_path)
    print(reverse_path(first_path or []))
    print()
    for path in bidirectional_search_all(start, 100000, solved_skewb):
        print(path)

    # for hp in good_hps:
    #     evaluate_permutation(hp, seed=64, sample_size=200)
    # evaluate_all_1_swaps(top_down)
    # evaluate_all_1_swaps_except_first(hp, 100)
    # evaluate_all_1_swaps_except_first(hp, sample_size=10)
    # start = hard = Skewb((O0, B0, R0, G2), (B0, R0, G1, O0), (BY, RY, GY, OY, YY))
    # start = subtle_invalid = Skewb((O0, B0, R0, G2), (B0, R0, G1, O0), (BY, RY, GY, OY, YY))
    # start = bed = Skewb((G0, O0, B0, R0), (G0, R1, B2, O2), (BY, OY, RO, GR, YY))
    # print(bidirectional_search(start, max_steps=1000000))
    # print(get_paths_from_heuristic(start, hp))