猫でもわかるWeb開発・プログラミング

本業エンジニアリングマネージャー。副業Webエンジニア。Web開発のヒントや、副業、日常生活のことを書きます。

AOJ 1250: Leaky Cryptography

問題概要

9つの数字が与えられる。最初の8つが暗号化された16進数の数字、9つ目がチェックサムである。しかし、この暗号化には脆弱性がある。これらの暗号は8つの暗号化される前の数字を C_1, C_2, ... C_8、暗号化のキーをK、チェックサムをSとすると。

{(C_1 ^ K) + (C_2 ^ K) + ... + (C_8 ^ K)} % 232 = SK

となる。C_1, ... C_8 と S が与えられとKが分かってしまうのである。このKを求めよ。

方針

問題の意味さえ分かってしまえばやるだけである。 わかりやすいように、2進数で考える。また、数字2つとチェックサムだけを考える。

例えば C_1 = 0001, C_2 = 0011 であったとする。この時、S = 0111 であったとする。

まず、1桁目(一番下の桁)を見てみると、1 + 1 でSの一桁目は1にならないので、これを合わせるためにKの1桁目は1となる。C_1, C_2 は現状それぞれ 0000, 0010 であったことになる。

次の2桁目はは矛盾がない。3桁目は矛盾している。0000と0010の和をとったら3桁目は1にならないので、K3桁目は1となる。C_1, C_2 は現状それぞれ 0100, 0110 となる。

続いて4桁目を見てみると、0100 + 0110 で4桁目は1になるはずなので矛盾している…

といった具合にKを求めていく。ビット操作に慣れてないと難しいかもしれないし、説明も難解になってしまった。さらにコードもシンプルなわりに難解になる傾向になるのでなれるまで難しい。ビット操作の問題を解いてなれるしかない気がする。

また、Javaはlong型を使うのが無難。32ビットの整数なので、intだと負の補数表現などの影響でうまくいかないことがある。

ソースコード

class Main extends MyUtil{
    static int n;
    static long mod = 1L << 32;
    
    public static void main(String[] args) throws Exception{
        
        Scanner sc = new Scanner(new InputStreamReader(System.in));
        int n = sc.nextInt();
        
        for(int i = 0; i < n; i++){
            // 入力
            long[] c = new long[9];
            for(int j = 0; j < 9; j++){
                c[j] = Long.parseLong(sc.next(), 16);
            }
            // ここまで入力
            
            // ビットを合わせていく操作
            long d = 1;
            long k = 0;
            for(int j = 0; j < 32; j++){
                long a = (sum(c) / d) % 2;
                long b = (c[8] / d) % 2;
                
                if(a != b){
                    k += d; 
                    xor(c, d);
                }
                d *= 2;
            }
            // 16進数に変換して出力
            System.out.println(Long.toHexString(k));
        }
    }
    
    // 配列とkのXORをとっていく
    static void xor(long[] arr, long k){
        for(int i = 0; i < 8; i++){
            arr[i] = arr[i] ^ k;
        }
    }
    
    // 配列の中身の和を2^32で割ったあまりを出す
    static long sum(long[] arr){
        long sum = 0;
        for(int i = 0; i < 8; i++){
            sum += arr[i];
            sum %= mod;
        }
        return sum;
    }
}