diff --git a/rtiww/src/main.cpp b/rtiww/src/main.cpp index 10a54f0..60e281d 100644 --- a/rtiww/src/main.cpp +++ b/rtiww/src/main.cpp @@ -46,9 +46,9 @@ int main() { hittable_list world; auto material_ground = make_shared(color(0.8, 0.8, 0.0)); - auto material_center = make_shared(1.5); + auto material_center = make_shared(color(0.1, 0.2, 0.5)); auto material_left = make_shared(1.5); - auto material_right = make_shared(color(0.8, 0.6, 0.2), 1.0); + auto material_right = make_shared(color(0.8, 0.6, 0.2), 0.0); world.add( make_shared(point3(0.0, -100.5, -1.0), 100.0, material_ground)); diff --git a/rtiww/src/material.h b/rtiww/src/material.h index c3837e5..e026201 100644 --- a/rtiww/src/material.h +++ b/rtiww/src/material.h @@ -57,14 +57,23 @@ public: virtual bool scatter(const ray &r_in, const hit_record &rec, color &attenuation, ray &scattered) const override { - attenuation = color(1.0, 1.0, 1.0); - double refraction_ratio = rec.front_face ? (1.0 / ir) : ir; + attenuation = color(1.0, 1.0, 1.0); + double refraction_ratio = rec.front_face ? (1.0/ir) : ir; - vec3 unit_direction = unit_vector(r_in.direction()); - vec3 refracted = refract(unit_direction, rec.normal, refraction_ratio); + vec3 unit_direction = unit_vector(r_in.direction()); + double cos_theta = fmin(dot(-unit_direction, rec.normal), 1.0); + double sin_theta = sqrt(1.0 - cos_theta*cos_theta); - scattered = ray(rec.p, refracted); - return true; + bool cannot_refract = refraction_ratio * sin_theta > 1.0; + vec3 direction; + + if (cannot_refract) + direction = reflect(unit_direction, rec.normal); + else + direction = refract(unit_direction, rec.normal, refraction_ratio); + + scattered = ray(rec.p, direction); + return true; } public: