Skip to content

Commit 3c37470

Browse files
authored
Better example code for the subsection "1. Replaced a variable with a tensor"
# Better example code for the subsection "1. Replaced a variable with a tensor" **Position:** "1. Replaced a variable with a tensor" subsection **Link:** https://www.tensorflow.org/guide/autodiff#1_replaced_a_variable_with_a_tensor **Condition:** The original example is very concise. However, it took me a considerable amount of time to modify the code and test it to understand the example fully. I think there are two points that confused me. Point 1: The derivative of `y = x + 1` is a constant scalar `1`. This will cause no value change as the number of epoch increases. We can't obviously see the change in the value. Point 2: The comment directly gives the correct code `# This should be x.assign_add(1)` without the reason. **Suggestion:** I have modified the example code as shown below. 1. The formula has been changed from `y = x + 1` to `y = x**2`. The derivative is now `2x`, and we can obviously see the value change corresponding to epoch. 2. A comment "The `tf.Variable` has been inadvertently replaced with a `tf.Tensor`." has been added. 3. A epoch indicator is added. `print("epoch:", epoch)` 4. The number of epoch is changed from 2 to 3. REF: https://github.com/HsienChing/ML_DL_project_State_Estimation_of_Li-ion_Batteries/blob/main/other/Issues_in_TensorFlow_official_doc_Introduction_to_gradients_and_automatic_differentiation.ipynb
1 parent 7c36502 commit 3c37470

1 file changed

Lines changed: 5 additions & 3 deletions

File tree

site/en/guide/autodiff.ipynb

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -799,12 +799,14 @@
799799
"source": [
800800
"x = tf.Variable(2.0)\n",
801801
"\n",
802-
"for epoch in range(2):\n",
802+
"for epoch in range(3):\n",
803803
" with tf.GradientTape() as tape:\n",
804-
" y = x+1\n",
804+
" y = x**2\n",
805805
"\n",
806+
" print("epoch:", epoch)\n,
806807
" print(type(x).__name__, \":\", tape.gradient(y, x))\n",
807-
" x = x + 1 # This should be `x.assign_add(1)`"
808+
" x = x + 1 # The `tf.Variable` has been inadvertently replaced with a `tf.Tensor`."
809+
" # This should be `x.assign_add(1)`."
808810
]
809811
},
810812
{

0 commit comments

Comments
 (0)