from math import sqrt import shelve from collections import deque from dataclasses import dataclass, replace from functools import reduce import random from typing import TYPE_CHECKING, Callable, Counter, Dict, List, Literal, Set, Tuple from functools import lru_cache import pytest print() # twisting the bottom, opposing this color clockwise when this color is facing away # RBOG Axis = Literal["G", "O", "B", "R"] AXES: tuple[Axis, ...] = Axis.__args__ # How many clockwise twists along an axis has been twisted after being flush with the top or bottom. CornerRotState = Literal[0, 1, 2] Twist = Literal["G", "O", "B", "R", "g", "o", "b", "r"] TWISTS: tuple[Twist, ...] = Twist.__args__ to_anticlockwise: dict[Axis, Twist] = {"G": "g", "O": "o", "B": "b", "R": "r"} to_clockwise: dict[Twist, Axis] = {v: k for k, v in to_anticlockwise.items()} to_opposite: dict[Twist, Twist] = {**to_anticlockwise, **to_clockwise} @dataclass(frozen=True) class Corner: col: Axis rot: CornerRotState def __repr__(self) -> str: return f"{self.col}{self.rot}" MidRotState = Literal["Y", "O", "R"] @dataclass(frozen=True) class Middle: col: Literal["G", "O", "B", "R", "Y"] rot: MidRotState def __repr__(self) -> str: return f"{self.col}{self.rot}" @dataclass(frozen=True) class Skewb: """Represents a Rubics cube variant sold as 'Qiyi Twisty Skewb'.""" # order: vertically below Axis primary color top: Tuple[Corner, Corner, Corner, Corner] bot: Tuple[Corner, Corner, Corner, Corner] mids: Tuple[Middle, Middle, Middle, Middle, Middle] # def __post_init__(self): # """Check basic constraints such as the correct number of colors""" def assert_valid(self): assert Counter( corner.col for corners in [self.top, self.bot] for corner in corners ) == {col: 2 for col in AXES} assert Counter(mid.col for mid in self.mids) == Counter("RBOGY") # forbidden rotations assert self.mids[0].rot != "R" assert self.mids[1].rot != "O" assert self.mids[2].rot != "R" assert self.mids[3].rot != "O" assert self.mids[4].col == "Y" or self.mids[4].rot != "Y" G0 = Corner("G", 0) G1 = Corner("G", 1) G2 = Corner("G", 2) O0 = Corner("O", 0) O1 = Corner("O", 1) O2 = Corner("O", 2) B0 = Corner("B", 0) B1 = Corner("B", 1) B2 = Corner("B", 2) R0 = Corner("R", 0) R1 = Corner("R", 1) R2 = Corner("R", 2) GR = Middle("G", "R") OR = Middle("O", "R") BR = Middle("B", "R") RR = Middle("R", "R") GO = Middle("G", "O") OO = Middle("O", "O") BO = Middle("B", "O") RO = Middle("R", "O") GY = Middle("G", "Y") OY = Middle("O", "Y") BY = Middle("B", "Y") RY = Middle("R", "Y") YY = Middle("Y", "Y") def rotate_mid_about_W(m: Middle) -> Middle: return Middle( col=m.col, rot="Y" if m.rot == "Y" else ("R" if m.rot == "O" else "O") ) def rotate_everything_about_W(s: Skewb) -> Skewb: """Clockwise rotation when looking down upon white.""" return Skewb( top=(s.top[-1],) + s.top[:-1], bot=(s.bot[-1],) + s.bot[:-1], mids=( rotate_mid_about_W(s.mids[3]), rotate_mid_about_W(s.mids[0]), rotate_mid_about_W(s.mids[1]), rotate_mid_about_W(s.mids[2]), rotate_mid_about_W(s.mids[4]), ), ) solved_skewb = Skewb( top=(G0, O0, B0, R0), bot=(G0, O0, B0, R0), mids=(GY, OY, BY, RY, YY), ) @pytest.mark.parametrize("s", [solved_skewb]) def test_rotate_everything_about_W(s: Skewb): ss = [s] for i in range(4): ss.append(rotate_everything_about_W(ss[-1])) assert ss[0] == ss[-1] assert all(ss[0] != q for q in ss[1:-1]) def test_axes(): assert AXES[0] == "G" type CornerRotPermutation = dict[CornerRotState, CornerRotState] BOT_LEFT_TO_TOP: CornerRotPermutation = {0: 2, 1: 0, 2: 1} TOP_TO_BOT_RIGHT: CornerRotPermutation = {0: 2, 1: 0, 2: 1} ROTATE_CORNER_CLOCKWISE: CornerRotPermutation = {0: 1, 1: 2, 2: 0} BOT_RIGHT_TO_BOT_LEFT: CornerRotPermutation = {0: 2, 1: 0, 2: 1} @pytest.mark.parametrize( "p", [BOT_LEFT_TO_TOP, TOP_TO_BOT_RIGHT, ROTATE_CORNER_CLOCKWISE, BOT_RIGHT_TO_BOT_LEFT], ) def test_rotation_permutations(p: CornerRotPermutation): assert set(p.keys()) == {0, 1, 2} assert set(p.values()) == {0, 1, 2} MID_DIR_INCREMENT: dict[MidRotState, MidRotState] = {"Y": "O", "R": "Y", "O": "R"} def clockwise_twist(s: Skewb, axis: Axis) -> Skewb: """Applies a clockwise twist to the axis. Coordinate frame is define as follows: - axis goes through the `top` of a solved cube of a given color - the axis-top is on top and faces away from the observer twisting it, - clockwise is twisting the bottom when looking up from diagonally-below the skewb """ rot_before, rot_after = {"G": (0, 0), "O": (3, 1), "B": (2, 2), "R": (1, 3)}[axis] for _ in range(rot_before): s = rotate_everything_about_W(s) s = Skewb( top=( s.top[0], s.top[1], Corner((c := s.bot[0]).col, BOT_LEFT_TO_TOP[c.rot]), s.top[3], ), bot=( Corner((c := s.bot[2]).col, BOT_RIGHT_TO_BOT_LEFT[c.rot]), s.bot[1], Corner((c := s.top[2]).col, TOP_TO_BOT_RIGHT[c.rot]), Corner((c := s.bot[3]).col, ROTATE_CORNER_CLOCKWISE[c.rot]), ), mids=( s.mids[0], s.mids[1], Middle( (m := s.mids[3]).col, "Y" if m.col == "Y" else MID_DIR_INCREMENT[m.rot], ), Middle( (m := s.mids[4]).col, "Y" if m.col == "Y" else MID_DIR_INCREMENT[m.rot], ), Middle( (m := s.mids[2]).col, "Y" if m.col == "Y" else ("O" if m.rot == "Y" else "R"), ), ), ) for _ in range(rot_after): s = rotate_everything_about_W(s) return s @pytest.mark.parametrize( "axis, want", [ ( "O", Skewb( top=(G0, O0, B0, O2), bot=(G1, R2, B0, R2), mids=(YY, OY, BY, GR, RR) ), ) ], ) def test_clockwise_twist_from_solved(axis: Axis, want: Skewb): want.assert_valid() assert clockwise_twist(solved_skewb, axis) == want def anticlockwise_twist(s: Skewb, twist: Axis) -> Skewb: return clockwise_twist(clockwise_twist(s, twist), twist) @pytest.mark.parametrize("start", [solved_skewb]) @pytest.mark.parametrize("axis", AXES) def test_clockwise_twist(start: Skewb, axis: Axis): assert anticlockwise_twist(clockwise_twist(start, axis), axis) == start assert clockwise_twist(anticlockwise_twist(start, axis), axis) == start ss = [start] for i in range(3): ss.append(clockwise_twist(ss[-1], axis)) assert len(ss) == 4 assert ss[0] == ss[-1] assert all([ss[0] != q for q in ss[1:-1]]) for s in ss: assert anticlockwise_twist(clockwise_twist(s, axis), axis) == s assert clockwise_twist(anticlockwise_twist(s, axis), axis) == s def apply_twist(start: Skewb, twist: Twist) -> Skewb: if twist in AXES: return clockwise_twist(start, twist) return anticlockwise_twist(start, to_clockwise[twist]) def apply_opposite(s: Skewb, twist: Twist) -> Skewb: return apply_twist(s, to_opposite[twist]) def apply_twists(start: Skewb, twists: list[Twist]) -> Skewb: return reduce(apply_twist, twists, start) def instructions(twists: list[Axis]) -> str: out = "" axis = "R" for twist in twists: while axis != twist: axis = AXES[(1 + AXES.index(axis)) % 4] out += "L" out += "." out = out.replace("..", ":") out = out.replace("LLL", "J") return out def breadth_first_search( start: Skewb, is_end: Callable[[Skewb], bool], max_steps: int = 2000000, bidirectional_fallback_threshold: int | None = None, ) -> list[Twist] | None: start.assert_valid() if is_end(start): return [] q = deque([start]) # what action got us to this point skewb_to_twist: dict[Skewb, Twist | None] = {skewb: None for skewb in q} def get_path(end: Skewb) -> list[Twist]: out = [] s = end while twist := skewb_to_twist[s]: out.append(twist) s = apply_opposite(s, twist) out.reverse() return out step_count = 0 while q and max_steps > 0: if max_steps % 1000 == 0 and step_count: print(".", end="", flush=True) if ( bidirectional_fallback_threshold is not None and step_count > bidirectional_fallback_threshold ): return bidirectional_search(start, max_steps) step_count += 1 parent = q.popleft() for twist in TWISTS: child = apply_twist(parent, twist) if child in skewb_to_twist: continue skewb_to_twist[child] = twist if is_end(child): return get_path(child) q.append(child) return None def test_breadth_first_search(): for twist in TWISTS: assert breadth_first_search( apply_opposite(solved_skewb, twist), is_end=lambda s: s == solved_skewb ) == [twist] def print_path(path: list[Axis]): x = start print(f"S -> {x}") for twist in path: x = clockwise_twist(x, twist) print(f"{twist} -> {x}") def bidirectional_search( start: Skewb, max_steps: int, end: Skewb = solved_skewb ) -> list[Twist] | None: start.assert_valid() q = deque([start]) q2 = deque([end]) # what action got us to this point skewb_to_twist: dict[Skewb, Twist | None] = {skewb: None for skewb in q} skewb_to_twist2: dict[Skewb, Twist | None] = {skewb: None for skewb in q2} def get_path(meet: Skewb) -> list[Twist]: path = [] s = meet while twist := skewb_to_twist[s]: path.append(twist) s = apply_opposite(s, twist) path.reverse() s = meet while twist := skewb_to_twist2[s]: path.append(twist) s = apply_twist(s, twist) return path def instructions(twists: list[Axis]) -> str: out = "" axis = "R" for twist in twists: while axis != twist: axis = AXES[(1 + AXES.index(axis)) % 4] out += "L" out += "." out = out.replace("..", ":") out = out.replace("LLL", "J") return out def on_meet(meet: Skewb) -> list[Twist]: path = get_path(meet) assert apply_twists(start, path) == end return path # print(f"{heuristic_list(end)=} {''.join(path)=} {instructions(path)=}") # return while q and max_steps > 0: max_steps -= 1 if max_steps % 1000 == 0: print(".", end="", flush=True) parent2 = q2.popleft() for twist in TWISTS: child = apply_opposite(parent2, twist) if child in skewb_to_twist2: continue skewb_to_twist2[child] = twist if child in skewb_to_twist: return on_meet(child) q2.append(child) parent = q.popleft() for twist in TWISTS: child = apply_twist(parent, twist) if child in skewb_to_twist: continue skewb_to_twist[child] = twist if child in skewb_to_twist2: return on_meet(child) q.append(child) print(f"{len(skewb_to_twist)}") return None def heuristic(got: Skewb) -> float: return sum(heuristic_list(got)) def heuristic_list(got: Skewb) -> list[float]: out = [] want = solved_skewb out.append(10000 * (got.top[0] == want.top[0])) out.append(1000 * (got.bot[3] == want.bot[3])) out.append(1000 * (got.mids[3] == want.mids[3])) out.append(100 * (got.top[1] == want.top[1])) out.append(100 * (got.mids[0] == want.mids[0])) out.append(100 * (got.bot[0] == want.bot[0])) out.append(100 * (got.mids[4] == want.mids[4])) for c_got, c_want in zip(got.top + got.bot, want.top + want.bot): out.append(1 * (c_got.col == c_want.col) + 2 * (c_got.rot == c_want.rot)) for m_got, m_want in zip(got.mids, want.mids): out.append(1 * (m_got.rot == m_want.rot) + 4 * (m_got.col == m_want.col)) return out def get_heuristic_matches(s: Skewb) -> List[int]: """ format: [0:4] top [4:8] bot [8:13] mids """ out = [ int((c_got.col == c_want.col) and (c_got.rot == c_want.rot)) for c_got, c_want in zip(s.top + s.bot, solved_skewb.top + solved_skewb.bot) ] out.extend( int((m_got.rot == m_want.rot) and (m_got.col == m_want.col)) for m_got, m_want in zip(s.mids, solved_skewb.mids) ) return out def element_multiply(a: list[int], b: list[int]) -> list[int]: assert len(a) == len(b) return [x * y for x, y in zip(a, b)] def test_heuristic(): assert heuristic(apply_twist(solved_skewb, "O")) < heuristic(solved_skewb) def random_skewb(seed: int = 4, twists: int = 20) -> Skewb: return apply_twists(solved_skewb, random_skewb_twists(seed, twists)) def random_skewb_twists(seed: int = 4, twists: int = 20) -> list[Twist]: out: list[Twist] = [] rng = random.Random(seed) ax_i = rng.randint(0, 3) for _ in range(twists): twist = AXES[ax_i] if rng.getrandbits(1): twist = to_opposite[twist] out.append(twist) ax_i = (ax_i + rng.randint(1, 3)) % len(AXES) return out def double_clockwise_to_anticlockwise(twists: list[Axis]) -> list[Twist]: i = 0 n = len(twists) out: list[Twist] = [] while i < n: twist = twists[i] if i < n - 1 and twist == twists[i + 1]: out.append(to_anticlockwise[twist]) i += 2 else: out.append(twist) i += 1 return out @pytest.mark.parametrize( "twists, want", [ ("RBOG", "RBOG"), ("RRBBOOGG", "rbog"), ("RRR", "rR"), ("RROR", "rOR"), ("RORR", "ROr"), ], ) def test_double_clockwise_to_anticlockwise(twists: list[Axis], want: list[Twist]): assert double_clockwise_to_anticlockwise(twists) == list(want) def test_random_skewb(): twist_count = 50 twists = random_skewb_twists(twists=twist_count) assert len(twists) == twist_count def shelve_it(file_name): d = shelve.open(file_name) def decorator(func): def new_func(*args, **kwargs): key = str(args) + str(kwargs) if key not in d: d[key] = func(*args, **kwargs) return d[key] return new_func return decorator def get_paths_from_heuristic( start: Skewb, heuristic_permutation: list[int] ) -> list[list[Twist]]: out: list[list[Twist]] = [] s = start mask = [0 for _ in heuristic_permutation] total_path_length = 0 for heuristic_i in heuristic_permutation: mask[heuristic_i] = 1 def step_finished(candidate: Skewb) -> bool: matches = get_heuristic_matches(candidate) assert len(matches) == len(mask) return all(match >= m for match, m in zip(matches, mask)) print(f"{mask=} {s=}", end=" ") if heuristic_i == len(heuristic_permutation) - 3: # print("going bidirectional now.") path = bidirectional_search(s, max_steps=200000) else: path = breadth_first_search( s, step_finished, bidirectional_fallback_threshold=20000 ) print(f" {path=}") if path is None: raise ValueError("oh no! solver could not find solution") out.append(path) s = apply_twists(s, path) total_path_length += len(path) return out # def get_total_path_length(start: Skewb, heuristic_permutation: list[int]) -> int: # return sum(len(p) for p in get_paths_from_heuristic(start, heuristic_permutation)) # close_to_wrongly_solved = Skewb(top=(R0, B0, O0, G0), bot=(B0, O0, G0, R0), mids=(BY, GRB, ORG, RY, YY)) near_end = Skewb(top=(O0, B0, R0, G2), bot=(B0, R0, G1, O0), mids=(BY, RY, GY, OY, YY)) start = near_end # start = Skewb(top=(O0, B0, R1, G1), bot=(B0, R2, G2, O0), mids=(BY, RY, GY, OY, YY)) HURISTIC_PERMUTATION_LENGTH = 4 + 4 + 5 def quadratic_mean(values: list[float]) -> float: return sqrt(sum(x * x for x in values) / len(values)) def score_fn(values: list[float]) -> float: return -sum(2**i * v for i, v in enumerate(sorted(values))) @shelve_it("skewb_solver.evaluate_permutation.shelve.sqlite") def evaluate_permutation( heuristic_permutation: list[int], seed=4, sample_size: int = 10 ) -> float: scores = [] for i in range(sample_size): paths = get_paths_from_heuristic(start, heuristic_permutation) score = score_fn([len(p) for p in paths]) print(f"{score=}") scores.append(score) return sum(scores) / len(scores) def evaluate_all_1_swaps(hp: list[int]): for i in range(len(hp)): for j in range(i): hp[i], hp[j] = hp[j], hp[i] evaluate_permutation(hp, sample_size=22) hp[i], hp[j] = hp[j], hp[i] def evaluate_all_1_swaps_except_first(hp: list[int]): for i in range(len(hp)): for j in range(1, i): hp[i], hp[j] = hp[j], hp[i] evaluation = evaluate_permutation(hp, sample_size=100) print(f"{hp=} {evaluation=}") hp[i], hp[j] = hp[j], hp[i] if __name__ == "__main__": hp = top_bot_mids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] hp = top_down = [0, 1, 2, 3, 8, 9, 10, 11, 12, 4, 5, 6, 7] # evaluate_all_1_swaps(top_down) evaluate_all_1_swaps_except_first(hp) start = Skewb((G0, O0, B0, R0), (G0, R1, B2, O2), (BY, OY, RO, GR, YY)) # print(bidirectional_search(start, max_steps=1000000)) # print(get_paths_from_heuristic(start, hp))