diff --git a/keras_cv/models/stable_diffusion/stable_diffusion.py b/keras_cv/models/stable_diffusion/stable_diffusion.py
index 3b71b40..a961502 100644
--- a/keras_cv/models/stable_diffusion/stable_diffusion.py
+++ b/keras_cv/models/stable_diffusion/stable_diffusion.py
@@ -26,6 +26,7 @@ import math
 import numpy as np
 import tensorflow as tf
 from tensorflow import keras
+from keras import backend as K
 
 from keras_cv.models.stable_diffusion.clip_tokenizer import SimpleTokenizer
 from keras_cv.models.stable_diffusion.constants import _ALPHAS_CUMPROD
@@ -48,6 +49,7 @@ class StableDiffusionBase:
         img_height=512,
         img_width=512,
         jit_compile=False,
+        precision="fp32",
     ):
         # UNet requires multiples of 2**7 = 128
         img_height = round(img_height / 128) * 128
@@ -63,6 +65,7 @@ class StableDiffusionBase:
         self._tokenizer = None
 
         self.jit_compile = jit_compile
+        self.to_fp16 = (precision == "fp16")
 
     def text_to_image(
         self,
@@ -205,18 +208,21 @@ class StableDiffusionBase:
 
         # Iterative reverse diffusion stage
         timesteps = tf.range(1, 1000, 1000 // num_steps)
+        t_embs_lst = self._get_timesteps_embedding(timesteps, batch_size)
+        contexts = tf.concat((unconditional_context, context), 0)
+
         alphas, alphas_prev = self._get_initial_alphas(timesteps)
         progbar = keras.utils.Progbar(len(timesteps))
         iteration = 0
         for index, timestep in list(enumerate(timesteps))[::-1]:
             latent_prev = latent  # Set aside the previous latent vector
-            t_emb = self._get_timestep_embedding(timestep, batch_size)
-            unconditional_latent = self.diffusion_model.predict_on_batch(
-                [latent, t_emb, unconditional_context]
-            )
-            latent = self.diffusion_model.predict_on_batch(
-                [latent, t_emb, context]
-            )
+            latents = tf.concat((latent, latent), 0)
+            t_embs = t_embs_lst[index]
+
+            pred_latent = self.diffusion_model.predict_on_batch(
+                [latents, t_embs, contexts])
+            unconditional_latent, latent = tf.split(pred_latent, 2)
+
             latent = unconditional_latent + unconditional_guidance_scale * (
                 latent - unconditional_latent
             )
@@ -458,6 +464,23 @@ class StableDiffusionBase:
             self._tokenizer = SimpleTokenizer()
         return self._tokenizer
 
+    def _get_timesteps_embedding(
+        self, timesteps, batch_size, dim=320, max_period=10000
+    ):
+        half = dim // 2
+        freqs = tf.math.exp(
+            -math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half
+        )
+        # timesteps shape: [num_steps]
+        args = tf.cast(tf.reshape(timesteps, [-1, 1]), dtype=tf.float32) * freqs
+        # embeddings shape:(steps, half)
+        embeddings = tf.concat([tf.math.cos(args), tf.math.sin(args)], axis=1)
+        embeddings = tf.expand_dims(embeddings, axis=1)
+        if self.to_fp16:
+            embeddings = tf.cast(embeddings, tf.float16)
+        #  2 is to concatenate the embedding of two forward pass
+        return tf.repeat(embeddings, batch_size * 2, axis=1)
+
     def _get_timestep_embedding(
         self, timestep, batch_size, dim=320, max_period=10000
     ):
@@ -468,6 +491,8 @@ class StableDiffusionBase:
         args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
         embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
         embedding = tf.reshape(embedding, [1, -1])
+        if self.to_fp16:
+            embedding = tf.cast(embedding, tf.float16)
         return tf.repeat(embedding, batch_size, axis=0)
 
     def _get_initial_alphas(self, timesteps):
@@ -478,14 +503,25 @@ class StableDiffusionBase:
 
     def _get_initial_diffusion_noise(self, batch_size, seed):
         if seed is not None:
-            return tf.random.stateless_normal(
-                (batch_size, self.img_height // 8, self.img_width // 8, 4),
-                seed=[seed, seed],
-            )
+            if self.to_fp16:
+                return tf.random.stateless_normal(
+                    (batch_size, self.img_height // 8, self.img_width // 8, 4),
+                    seed=[seed, seed], dtype=tf.float16
+                )
+            else:
+                return tf.random.stateless_normal(
+                    (batch_size, self.img_height // 8, self.img_width // 8, 4),
+                    seed=[seed, seed],
+                )
         else:
-            return tf.random.normal(
-                (batch_size, self.img_height // 8, self.img_width // 8, 4)
-            )
+            if self.to_fp16:
+                return tf.random.normal(
+                        (batch_size, self.img_height // 8, self.img_width // 8, 4), dtype=tf.float16
+                    )
+            else:
+                return tf.random.normal(
+                    (batch_size, self.img_height // 8, self.img_width // 8, 4)
+                )
 
     @staticmethod
     def _get_pos_ids():
@@ -543,8 +579,9 @@ class StableDiffusion(StableDiffusionBase):
         img_height=512,
         img_width=512,
         jit_compile=False,
+        precision="fp32",
     ):
-        super().__init__(img_height, img_width, jit_compile)
+        super().__init__(img_height, img_width, jit_compile, precision)
         print(
             "By using this model checkpoint, you acknowledge that its usage is "
             "subject to the terms of the CreativeML Open RAIL-M license at "
@@ -558,7 +595,16 @@ class StableDiffusion(StableDiffusionBase):
         needs to be modified.
         """
         if self._text_encoder is None:
-            self._text_encoder = TextEncoder(MAX_PROMPT_LENGTH)
+            if self.to_fp16:
+                self._text_encoder_fp32 = TextEncoder(MAX_PROMPT_LENGTH)
+                weights_fp32 = self._text_encoder_fp32.get_weights()
+                K.set_floatx('float16')
+                weights_fp16 = [w.astype(K.floatx()) for w in weights_fp32]
+                self._text_encoder = TextEncoder(
+                    MAX_PROMPT_LENGTH, download_weights=False)
+                self._text_encoder.set_weights(weights_fp16)
+            else:
+                self._text_encoder = TextEncoder(MAX_PROMPT_LENGTH)
             if self.jit_compile:
                 self._text_encoder.compile(jit_compile=True)
         return self._text_encoder
@@ -626,8 +672,9 @@ class StableDiffusionV2(StableDiffusionBase):
         img_height=512,
         img_width=512,
         jit_compile=False,
+        precision="fp32",
     ):
-        super().__init__(img_height, img_width, jit_compile)
+        super().__init__(img_height, img_width, jit_compile, precision)
         print(
             "By using this model checkpoint, you acknowledge that its usage is "
             "subject to the terms of the CreativeML Open RAIL++-M license at "
@@ -641,7 +688,16 @@ class StableDiffusionV2(StableDiffusionBase):
         needs to be modified.
         """
         if self._text_encoder is None:
-            self._text_encoder = TextEncoderV2(MAX_PROMPT_LENGTH)
+            if self.to_fp16:
+                self._text_encoder_fp32 = TextEncoderV2(MAX_PROMPT_LENGTH)
+                weights_fp32 = self._text_encoder_fp32.get_weights()
+                K.set_floatx('float16')
+                weights_fp16 = [w.astype(K.floatx()) for w in weights_fp32]
+                self._text_encoder = TextEncoderV2(
+                    MAX_PROMPT_LENGTH, download_weights=False)
+                self._text_encoder.set_weights(weights_fp16)
+            else:
+                self._text_encoder = TextEncoderV2(MAX_PROMPT_LENGTH)
             if self.jit_compile:
                 self._text_encoder.compile(jit_compile=True)
         return self._text_encoder
